diff --git a/virtual-assistant-chatgpt/README.md b/virtual-assistant-chatgpt/README.md new file mode 100644 index 0000000..801dd4d --- /dev/null +++ b/virtual-assistant-chatgpt/README.md @@ -0,0 +1,166 @@
+# Weather Virtual Assistant Example
+
+## Overview
+
+The Virtual Assistant (VA) sample demonstrates how to use OpenAI's ChatGPT along with Riva AI Services to build a simple but complete conversational AI application. It demonstrates receiving spoken input from the user, interpreting the query via an intent recognition and slot filling approach, leveraging the LLM (ChatGPT) to generate a natural-sounding, human-like response, and speaking this back to the user in a natural voice.
+
+## Prerequisites
+
+- This demo uses NVIDIA Riva to support Speech AI capabilities like Automatic Speech Recognition (ASR) and Text-to-Speech (TTS). To run NVIDIA Riva Speech AI services, please ensure you have the prerequisites mentioned [here](https://docs.nvidia.com/deeplearning/riva/user-guide/docs/quick-start-guide.html#data-center).
+- For running this sample application, you'll need:
+  - Access to the [OpenAI platform](https://platform.openai.com/). You will need your [OpenAI API key](https://platform.openai.com/account/api-keys) to access the service through the API in this sample application.
+  - A Linux x86_64 environment with [pip](https://pypi.org/project/pip/) and Python 3.8+ installed.
+  - A [weatherstack API access key](https://weatherstack.com/documentation). The VA uses weatherstack for weather fulfillment; that is, when a weather intent is recognized, real-time weather information is fetched from weatherstack. Sign up for the free tier of [weatherstack](https://weatherstack.com/) to get your API access key.
+  - A microphone and speaker (for example, a Logitech H390 USB Computer Headset) to communicate with the app.
+
+
+### Setup
+
+1. Clone the NVIDIA Riva Sample Apps repo.
+
+2. Enter the home directory of the Virtual Assistant ChatGPT sample app:
+```bash
+cd /path/to/sample-apps/virtual-assistant-chatgpt
+```
+
+3. Create and activate a Python [virtual environment](https://packaging.python.org/en/latest/guides/installing-using-pip-and-virtual-environments/#creating-a-virtual-environment). For example:
+```
+python3 -m venv apps-env
+source apps-env/bin/activate
+```
+
+After activation, checking the Python version should show the version the environment was created with. For example:
+```
+python3 --version
+```
+*Python 3.8.10*
+
+
+4. Install the libraries necessary for the virtual assistant, including the Riva and OpenAI client libraries:
+   1. Install the weatherbot web application dependencies. `requirements.txt` captures all Python dependencies needed for the weatherbot web application.
+   ```bash
+   pip install -r requirements.txt # Tested with Python 3.8
+   ```
+   2. Install the Riva client library.
+   ```
+   pip install nvidia-riva-client
+   ```
+   3. Install the OpenAI client library.
+   ```bash
+   pip install openai
+   ```
+
+### Running the demo
+1. Start the Riva Speech Server, if not already done, by following the steps in the [Riva Quick Start Guide](https://docs.nvidia.com/deeplearning/riva/user-guide/docs/quick-start-guide.html). This enables the Speech AI capabilities required by the demo. **Note the IP & port** where the Riva server is running; by default, it serves at `<riva_ip>:50051`.
+   1. In the `config.sh` script included in the
+   [Riva Skills Quick Start](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/riva/resources/riva_quickstart) resource folder, set
+   ```bash
+   service_enabled_asr=true
+   service_enabled_nlp=true
+   service_enabled_tts=true
+   service_enabled_nmt=false
+   ```
+   2. In `config.sh`, under
+   ```bash
+   else
+     models_nlp=(
+     ...
+   ```
+   comment out or uncomment the names of the NLP models as desired. At a minimum,
+   this sample app requires the BERT Base Intent Slot model. The BERT Base
+   Punctuation and Capitalization model and BERT Base Named Entity Recognition
+   model are also recommended.
+
+   3. In a terminal, in the Riva Skills Quick Start resource folder's home
+   directory, run
+   ```bash
+   ./riva_init.sh
+   ```
+   to download and deploy the speech AI models. Then, run
+   ```bash
+   ./riva_start.sh
+   ```
+   to start the Riva Server.
+
+2. Edit the configuration file [config.py](./config.py):
+   1. In `riva_config`, set:
+      * The Riva speech server URL. This is the endpoint where the Riva services can be accessed.
+      * The [weatherstack API access key](https://weatherstack.com/documentation). The VA uses weatherstack for weather fulfillment; that is, when a weather intent is recognized, real-time weather information is fetched from weatherstack. Sign up for the free tier of [weatherstack](https://weatherstack.com/) to get your API access key.
+   2. In `llm_config`, set:
+      * Your OpenAI API key.
+      * Optionally, the GPT model to use. By default, this is set to "gpt-3.5-turbo"; see
+        https://platform.openai.com/docs/models for more options.
+
+The code snippets will look like the example below.
+```python3
+riva_config = {
+    "RIVA_SPEECH_API_URL": "<riva_ip>:<port>", # Replace the IP & port with your hosted Riva endpoint
+    ...
+    "WEATHERSTACK_ACCESS_KEY": "", # Get your access key at - https://weatherstack.com/
+    ...
+}
+...
+llm_config = {
+    ...
+    "API_MODEL_NAME": "gpt-3.5-turbo",
+    "API_KEY": "", # Get your access key at https://platform.openai.com/account/api-keys
+    ...
+}
+```
+
+3. Run the virtual assistant application:
+```bash
+python3 main.py
+```
+
+4. Open a browser to **https://<IP>:8009/rivaWeather**, where <IP> is the address of the machine where the application is running. For instance, go to https://127.0.0.1:8009/rivaWeather if the app is running on your local machine.
+
+5. Speak to the virtual assistant through your microphone or type in your text, asking a weather-related query. To hear text-to-speech audio of the LLM response, click "Unmute System Speech" in the bottom-right corner of the UI.
+
+`NOTE:` To learn about the call to the LLM Service, please refer to the `query_llm` method in `riva_local/chatbot/stateDM/Util.py`. (A hedged sketch of what such a call can look like follows the sample use cases below.)
+
+## Sample Use Cases
+It is possible to ask the bot the following types of questions:
+
+* What is the weather in Berlin?
+
+* What is the weather?
+  * For which location?
+
+* What’s the weather like in San Francisco tomorrow?
+  * What about in California City?
+
+* What is the temperature in Paris on Friday?
+
+* How hot is it in Berlin today?
+
+* Is it currently cold in San Francisco?
+
+* Is it going to rain in Detroit tomorrow?
+
+* How much rain in Seattle?
+
+* Will it be sunny next week in Santa Clara?
+
+* Is it cloudy today?
+
+* Is it going to snow tomorrow in Milwaukee?
+
+* How much snow is there in Toronto currently?
+
+* How humid is it right now?
+
+* What is the humidity in Miami?
+
+* What's the humidity level in San Diego?
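As referenced in the `NOTE` above, the LLM request is made in `riva_local/chatbot/stateDM/Util.py`. Below is a minimal sketch, not the sample's verbatim code, of how the `llm_config` values in `config.py` can be passed to the OpenAI chat completions API using the 0.x-style `openai` package installed during setup; the function name and prompt handling here are illustrative assumptions.

```python
import openai

from config import llm_config

openai.api_key = llm_config["API_KEY"]

def query_llm(prompt: str) -> str:
    # Illustrative sketch: send the assembled weather prompt to the configured GPT model.
    response = openai.ChatCompletion.create(
        model=llm_config["API_MODEL_NAME"],            # e.g. "gpt-3.5-turbo"
        messages=[{"role": "user", "content": prompt}],
        max_tokens=llm_config["TOKENS_TO_GENERATE"],
        temperature=llm_config["TEMPERATURE"],
        top_p=llm_config["TOP_P"],
        stop=llm_config["STOP_WORDS"],
        presence_penalty=llm_config["PRESENCE_PENALTY"],
    )
    return response["choices"][0]["message"]["content"]
```

Because `STOP_WORDS` is `["\n"]`, generation stops at the first newline, which keeps the assistant's reply to a single paragraph.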
+
+## Limitations
+* The sample supports intents for weather, temperature, rain, humidity, sunny, cloudy, and snowfall checks. It does not support general conversational queries or other domains.
+* The sample supports only one slot, for city.
+* The sample supports up to four concurrent users. This restriction comes from the web framework in use (Flask with Flask-SocketIO): the socket connections that stream audio to (TTS) and from (ASR) the server cannot be sustained for more than four concurrent users.
+* The chatbot application is not optimized for low latency when multiple users are connected concurrently.
+* Some erratic behavior has been observed with the chatbot sample on the Firefox browser. The most common issue is the TTS output being picked up as ASR input at certain microphone gain values.
+
+## License
+The [NVIDIA Riva License Agreement](https://developer.nvidia.com/riva/ga/license) is included with the product. Licenses are also available along with the model application zip file. By pulling and using the Riva SDK container, downloading models, or using the sample applications here, you accept the terms and conditions of these licenses.
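Before running the demo, it can be useful to confirm that the Riva endpoint configured in `riva_config` is reachable. A minimal sketch using the `nvidia-riva-client` package installed during setup; the endpoint below is an assumption, so substitute the `RIVA_SPEECH_API_URL` value from your `config.py`.

```python
import riva.client

# Assumed endpoint; replace with your RIVA_SPEECH_API_URL from config.py.
auth = riva.client.Auth(uri="localhost:50051")
tts = riva.client.SpeechSynthesisService(auth)

# A one-shot synthesis request; a non-empty audio payload confirms TTS is up.
response = tts.synthesize(
    text="Riva speech server is reachable.",
    voice_name="English-US.Female-1",
    language_code="en-US",
    sample_rate_hz=22050,
)
print(f"Received {len(response.audio)} bytes of audio")
```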
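For weather fulfillment, the VA queries weatherstack once a weather intent and city slot are recognized. A rough sketch of such a call against weatherstack's public `current` endpoint; the helper name and returned fields are illustrative, not the sample's exact fulfillment code.

```python
import requests

from config import riva_config

def fetch_current_weather(city: str) -> dict:
    # Query weatherstack's current-conditions endpoint for the recognized city slot.
    response = requests.get(
        "http://api.weatherstack.com/current",
        params={
            "access_key": riva_config["WEATHERSTACK_ACCESS_KEY"],
            "query": city,
        },
        timeout=10,
    )
    data = response.json()
    current = data["current"]
    return {
        "condition": current["weather_descriptions"][0],
        "temperature": current["temperature"],  # Celsius with default metric units
        "humidity": current["humidity"],        # percent
        "wind_speed": current["wind_speed"],    # km/h by default; pass units="f" for mph
    }
```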
diff --git a/virtual-assistant-chatgpt/client/webapplication/__pycache__/start_web_application.cpython-38.pyc b/virtual-assistant-chatgpt/client/webapplication/__pycache__/start_web_application.cpython-38.pyc new file mode 100644 index 0000000..c7004d8 Binary files /dev/null and b/virtual-assistant-chatgpt/client/webapplication/__pycache__/start_web_application.cpython-38.pyc differ diff --git a/virtual-assistant-chatgpt/client/webapplication/cert.pem b/virtual-assistant-chatgpt/client/webapplication/cert.pem new file mode 100644 index 0000000..f8951d3 --- /dev/null +++ b/virtual-assistant-chatgpt/client/webapplication/cert.pem @@ -0,0 +1,34 @@ +-----BEGIN CERTIFICATE----- +MIIF8TCCA9mgAwIBAgIUCWhCFWAPSu7g+Tn0D4+oYfLn9xUwDQYJKoZIhvcNAQEL +BQAwgYcxCzAJBgNVBAYTAlVTMRMwEQYDVQQIDApDYWxpZm9ybmlhMRUwEwYDVQQH +DAxTYW5GcmFuY2lzY28xEDAOBgNVBAoMB1FWQU5UVk0xCzAJBgNVBAsMAlFWMQsw +CQYDVQQDDAJTRDEgMB4GCSqGSIb3DQEJARYRcXZhbnR2bUBnbWFpbC5jb20wHhcN +MjAwMTE1MTg0NTAyWhcNMjEwMTE0MTg0NTAyWjCBhzELMAkGA1UEBhMCVVMxEzAR +BgNVBAgMCkNhbGlmb3JuaWExFTATBgNVBAcMDFNhbkZyYW5jaXNjbzEQMA4GA1UE +CgwHUVZBTlRWTTELMAkGA1UECwwCUVYxCzAJBgNVBAMMAlNEMSAwHgYJKoZIhvcN +AQkBFhFxdmFudHZtQGdtYWlsLmNvbTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCC +AgoCggIBAM62M4FzhOfntQARneCixYTl2mi8owa2spm600Lopy8Fwg2e5jkY/pyg +znK7foHBvW3eJKxENtwLKIUqNYyOA2F+96FW1fZ9MUF6zcYcccFimjjKr6tYrqDf +GIqAgxOcv4Syv/DaIm3tKDwIYFFtB/kALyG0vBHWV40fvOoqTwCQUSbptM9GBRZe ++beMUS83Zwk4OPPIiSX+P+DEzsBAqUbqU50ilJrzM7VZ+QpVooF/De4moRDSknVs +h/Zbn9a8Wbfb1XpBzm19mnO1vZ6CENUIFdCxvW4qrmfIu346bd+stXcwcJp7bLP0 +vmVnK5DpGIjmzo7n97nFJUO3kFykbJcWpTIwWSc7jsnak+HkwxXAYWZmz0sABDLA +QBK37NuYCqsDxvOibC2X4Y0oRxYut8R0nUYXOtKHBvR5Ug9njIov9lsV8acIm7s9 +r0fignHgFOXmTDMIjK469LwXf0vd1Wy6tZxjNS2ZOG+eyc5sKHCyVP6dEl8xYH6i +/oePn5PmzSr220comUny5NVkUewDYo63A90bQ4X3tdeI/XsPZDXhgQC36+wKFKgy +6WULQhZT4EpxMbiYNF8IjTr6w/5mk1oOT4/nRedcb7/wxlZiL5r/DJlMVuP7zlfq +n3CvLIlpiy6E2b1c7LU+mvVnkq0LQrD+yK30p3JSVwabm6GeY/tnAgMBAAGjUzBR +MB0GA1UdDgQWBBRs9s5GEerZSVFVKGAUIOxNDbE7HDAfBgNVHSMEGDAWgBRs9s5G +EerZSVFVKGAUIOxNDbE7HDAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUA +A4ICAQAv3d5cMLwb76ikYt37P10ggPOjyZwAs90mR/2f25+ADyUaX/wzHtitaA8t +cXfpm4pDICRtJjg5hb6A2Wh+ws/gODbyD915sqYDC6pz0FTxh3n2/BcCQZJWa7Oy +9k7RwixwIYjvGzkWaje9Xi5Jte7tBj4358QFUKKWiWarB+nl2MBAW3wKMRO5stst +pNZ2keaMeG/a5/Ms0PLIeSxN5oUNazs1m3NLB7XFx1ewMKdouamhuCDRyoPixtfm +pas94/IUdGyY6WYaZgGdQtBrM9ro8NtJNFiFQImI0NaQJQCgl45pmapceNzHDHgR +XHnbirr7rRJX6UutrEHNCZcwA5sFQ++AtQycW20vTGJZ1cO+mCYvWEKN0PC4KRhY +Aw8q8ROkX9RvxHl3WGFdpIFXAZtI8O1d7G9ySG1xBLWyYW7s58vdFnsqmjz4Y59b +MWWCIBIAhxLFBRrZ7KEGUj0lJjEuRZ9lkfS4CAAX2gUSlgU/GfUWe8R3bzAqC57t +awACY7PS3m29ILHuKEU4qRarSWdYaJP7ETc6TQDgxXVXA8NTQOMFj73zYILl7MYh +DXvHzk8xKGWqUv2gmKylsqQZPm/mgn/dssOV19TB6LOszX2rJ0ddZ63on5wJqDGl +6BiPGopQSm/mJPNfO3JjvetTEumfinvsac2eTZ99zbTiJ0Pyfw== +-----END CERTIFICATE----- diff --git a/virtual-assistant-chatgpt/client/webapplication/key.pem b/virtual-assistant-chatgpt/client/webapplication/key.pem new file mode 100644 index 0000000..0552bb1 --- /dev/null +++ b/virtual-assistant-chatgpt/client/webapplication/key.pem @@ -0,0 +1,52 @@ +-----BEGIN PRIVATE KEY----- +MIIJQgIBADANBgkqhkiG9w0BAQEFAASCCSwwggkoAgEAAoICAQDOtjOBc4Tn57UA +EZ3gosWE5dpovKMGtrKZutNC6KcvBcINnuY5GP6coM5yu36Bwb1t3iSsRDbcCyiF +KjWMjgNhfvehVtX2fTFBes3GHHHBYpo4yq+rWK6g3xiKgIMTnL+Esr/w2iJt7Sg8 +CGBRbQf5AC8htLwR1leNH7zqKk8AkFEm6bTPRgUWXvm3jFEvN2cJODjzyIkl/j/g +xM7AQKlG6lOdIpSa8zO1WfkKVaKBfw3uJqEQ0pJ1bIf2W5/WvFm329V6Qc5tfZpz +tb2eghDVCBXQsb1uKq5nyLt+Om3frLV3MHCae2yz9L5lZyuQ6RiI5s6O5/e5xSVD 
+t5BcpGyXFqUyMFknO47J2pPh5MMVwGFmZs9LAAQywEASt+zbmAqrA8bzomwtl+GN +KEcWLrfEdJ1GFzrShwb0eVIPZ4yKL/ZbFfGnCJu7Pa9H4oJx4BTl5kwzCIyuOvS8 +F39L3dVsurWcYzUtmThvnsnObChwslT+nRJfMWB+ov6Hj5+T5s0q9ttHKJlJ8uTV +ZFHsA2KOtwPdG0OF97XXiP17D2Q14YEAt+vsChSoMullC0IWU+BKcTG4mDRfCI06 ++sP+ZpNaDk+P50XnXG+/8MZWYi+a/wyZTFbj+85X6p9wryyJaYsuhNm9XOy1Ppr1 +Z5KtC0Kw/sit9KdyUlcGm5uhnmP7ZwIDAQABAoICAAa6TWDYNqopk22GJUJLaexS +YtJn2VJ9ncB9ISUbV12jbVZuJoYTNy432aBIU+y7NoQd58mnirWMs2vqHMYPVTLW +JA8fOWWFW5YK/imFgXpO0EAq8J6+Cyj3OeBAIIQB5QXXn4GiR96WCmoxx5i+2LSU ++fO54ykddcoFD2v7poiZKdr/XkAkwkOhIbWEnpvPzM2zA7+DdltDNCcHoMcHE7tY +IxKJLpcAdV1gqUdZ1CksznJC1ZkrkVK7Do3JG6Gsjar7P65z99j+bol3j81Z5Fxa +oAMj1cuBHh4InXmVQ0A1ac6QSAnvHHGa9JtuSS+1NnQ2NuDV0e086mKS1eL+Av60 +0upMs2wWHn20FclHfF3dPXGXW5U7D1DqY4Zy9p77Qoef8naeVWdruCoMGXEyv85w +H2ZDfeQgjnRQTiWinEvkhJfi+qXQvC689rOtxhUX3il8FWLExV+Hix2K64c5Ne9P +wCOflEQxwqBM3O259N9xaaKy5+fsMDoADzwQZtVEOZQbtPfIFIoZKhijSNpf72eb +MUjroZVWl4ZSniskMf2ZzMtnVCho98VIpX9fh80yGSQ1QGpp/XnzfhpBFNvLH7O1 +NA3+s6XxXVBve/VU5JDW8FEf7UxmppON6q3+Vcaq3YVMP8jpedWY7yAQriGX0gAZ +Bnqo18kH6U0RntFmOGrhAoIBAQD/MtH/tmA/1kuMnKD45Yu+XNtFdzljhelVxOx0 +EJX28ZE1X/pqOvbPUvTFa6EPULnfiajgzE622IQDCYgwkXM4yrn7R+zFA1WBq7O1 +dgbWZCdDA6RCpGK7/oSbOviPpovD3pfUkq9rhvh1VeGGgv0P3zZ5yZX6LUJtjyo6 +tk/1x4FHzPF4wcSgT2ZoQn+NpVBoSc6/8uUTmYSWjxq/wnhK7KZvbeqwhJ9ikFs/ +II5fmoqYwEyQbZQQNKgz1mpW+ZRurf2iEmyEjrY+E4mizqZhnXa8X6Nzn4B6wSeJ +wlK283PTh3NoTwXklTCqLXv1MkIrtE97lXgCEKR7JKfeM8v3AoIBAQDPXGXBxZu2 +KReej0YMrEoVyKyjqdnYb+MyBPNRiL7czlfMeKscldG2e+Z8mkBpVWBIw6KgY8JK +DPg7aziC9f0QUOgWKH33+9YN7xk4PSlkwQZhKyu9IQcYY2xo0LrptVtrzdn7/CyQ +ApdEq2xe6dB+r2Cd3vM1+2dhkuATZK87Nlm3kXTiz7K5qA9nawnzL2+u/rF45izu +kZezXMMpsKW6xxiePPyItVR+4Z4EZhg6kB/QNPkA8tjGptsgVBbiLJM6H6/xtfuf +Mwjh0Yv0jzHiTly+AJL9rx3y5pPqg2rULOSbWRxRLkx75q8KLwfOOMTf+OwtKjyZ +r8MchIkMbBARAoIBADlKtnx6/Ca4vGNH8peOKQ5GmG+C8Z5XPOglepQf+RrkZp4d ++wEIVcp7rDn6DMF8dQ4rQH+fPnisKQ7pf+qvbLeuQ4yXPB+KvRKMcp7lbWmKOIpB +8gmIECZ2YFzdI1pUoIILofh2Ke2w8mydKDFjjN6YVQmIaSQuLwCbqHZf4Zmi/XIa +H4flsHfw+2OisjIhj+ip0UGkjSsWRv7qB65PQWRItqDDg3G3hHTDRcjpTS1Ha6AZ +Y9b00s3ElJJ2q471Hw6t/wf4rOYhh+ZtynODgzTc/gASVIaro1Nrs62os5shErrF +aPJc80y69Z7u8So96z8WjtWG29dS1ypSM2GeLUcCggEBAIYhJlEpGYfDHNwboRwh +deqRW9qhy6AM/9EjEqDy60K41mIUy9o5ruVzT6vZu0BnUVi/8zn8TXjI2ujUekF8 +DK25J+btWk5GQDfTKWUPau8ZTJ8d5bT44DYOWdmS6tSx0ujwxsgQXmLoyiBJIlhi +tdK8bqqvxHJupHihIQBqaE7M4Uu0cv8jimA9LXmf61e6n2t6pCGoAfhvhMkof7U/ +5nPixTHWESP85yMLncMKpzF6eJmdKlRKwZ394FARFJxIaRN3279mD9TylhQ8D2Oq +HIJeXe8pP+uIkr7EF3nid/+26kjyYza/1AlxNlhIA6yJXA/kXCD66SggYPzZXi0C +2YECggEAAl9JXmVcoOEEXx7QoIRfFy8Tbz8dRpID7LInkuFWFc9vpLoCzAIajq78 +wxSQqnvcfUBGQn6ChmdIjSDgi5RJQ9ZxS0Pyze11a8RrXl295jyNJTPB/qMt4ECs +ZwI7yGzk1veDEett6UCZMsFA3SUUTgR9epZSYtg+pGV7cluDGRwj8HcZxPOCgtBB +UcZYd4PS/yOufVjujUVgdaYFLvk6y4WTZB2/HX99MzrHBFWP9NpJ7Ha5w55JcaS/ +XPKFeZLEPWizePuRvANCHoElEUm5DUlHcT4RwSHozGZwNhJhTOs/797p9FxXJKX7 +o2YTBJCBe5ojou7zXimYDgFm+By2AA== +-----END PRIVATE KEY----- diff --git a/virtual-assistant-chatgpt/client/webapplication/server/__init__.py b/virtual-assistant-chatgpt/client/webapplication/server/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/virtual-assistant-chatgpt/client/webapplication/server/__pycache__/__init__.cpython-38.pyc b/virtual-assistant-chatgpt/client/webapplication/server/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000..02755e0 Binary files /dev/null and b/virtual-assistant-chatgpt/client/webapplication/server/__pycache__/__init__.cpython-38.pyc differ diff --git 
a/virtual-assistant-chatgpt/client/webapplication/__pycache__/server.cpython-38.pyc b/virtual-assistant-chatgpt/client/webapplication/server/__pycache__/server.cpython-38.pyc new file mode 100644 index 0000000..bb53453 Binary files /dev/null and b/virtual-assistant-chatgpt/client/webapplication/server/__pycache__/server.cpython-38.pyc differ
diff --git a/virtual-assistant-chatgpt/client/webapplication/server/server.py b/virtual-assistant-chatgpt/client/webapplication/server/server.py new file mode 100644 index 0000000..ed7388a --- /dev/null +++ b/virtual-assistant-chatgpt/client/webapplication/server/server.py @@ -0,0 +1,166 @@
+# ==============================================================================
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# The License information can be found under the "License" section of the
+# README.md file.
+# ==============================================================================
+
+from __future__ import division
+
+import uuid
+import time
+from flask import Flask, jsonify, send_from_directory, Response, request, logging
+from flask_cors import CORS
+from flask import stream_with_context
+from flask_socketio import SocketIO, emit
+from os.path import dirname, abspath, join, isdir
+from os import listdir
+from config import client_config
+from engineio.payload import Payload
+
+from riva_local.chatbot.chatbots_multiconversations_management import create_chatbot, get_new_user_conversation_index, get_chatbot
+
+''' Flask Initialization
+'''
+app = Flask(__name__)
+cors = CORS(app)
+log = logging.logging.getLogger('werkzeug')
+log.setLevel(logging.logging.ERROR)
+Payload.max_decode_packets = 500 # https://github.com/miguelgrinberg/python-engineio/issues/142
+sio = SocketIO(app, logger=False)
+verbose = client_config['VERBOSE']
+
+# Routes that serve the client UI
+@app.route('/rivaWeather/')
+def get_bot1():
+    return send_from_directory("../ui/", "index.html")
+
+@app.route('/rivaWeather/<file>', defaults={'path': ''})
+@app.route('/rivaWeather/<path:path>/<file>')
+def get_bot2(path, file):
+    return send_from_directory("../ui/" + path, file)
+
+
+@app.route('/get_new_user_conversation_index')
+def get_newuser_conversation_index():
+    return get_new_user_conversation_index()
+
+# Audio source for TTS
+@app.route('/audio/<user_conversation_index>/<post_id>')
+def audio(user_conversation_index, post_id):
+    if verbose:
+        print(f'[{user_conversation_index}] audio speak: {post_id}')
+    currentChatbot = get_chatbot(user_conversation_index)
+    return Response(currentChatbot.get_tts_speech())
+
+# Handles ASR audio transcript output
+@app.route('/stream/<user_conversation_index>')
+def stream(user_conversation_index):
+    @stream_with_context
+    def audio_stream():
+        currentChatbot = get_chatbot(user_conversation_index)
+        if currentChatbot:
+            asr_transcript = currentChatbot.get_asr_transcript()
+            for t in asr_transcript:
+                yield t
+        params = {'response': "Audio Works"}
+        return params
+    return Response(audio_stream(), mimetype="text/event-stream")
+
+
+# Used for sending messages to the bot
+@app.route( "/", methods=['POST'])
+def get_input():
+    try:
+        text = request.json['text']
+        context = request.json['context']
+        bot = request.json['bot'].lower()
+        payload = request.json['payload']
+        user_conversation_index = request.json['user_conversation_index']
+    except KeyError:
+        return jsonify(ok=False, message="Missing parameters.")
+    if user_conversation_index:
+        create_chatbot(user_conversation_index, sio, verbose=client_config['VERBOSE'])
+        currentChatBot = get_chatbot(user_conversation_index)
+        try:
+            response = 
currentChatBot.stateDM.execute_state( + bot, context, text) + + if client_config['DEBUG']: + print(f"[{user_conversation_index}] Response from RivaDM: {response}") + + for resp in response['response']: + speak = resp['payload']['text'] + if len(speak): + currentChatBot.tts_fill_buffer(speak) + return jsonify(ok=True, messages=response['response'], context=response['context'], + session=user_conversation_index, debug=client_config["DEBUG"]) + except Exception as e: # Error in execution + + print(e) + return jsonify(ok=False, message="Error during execution.") + else: + print("user_conversation_index not found") + return jsonify(ok=False, message="user_conversation_index not found") + + +# Writes audio data to ASR buffer +@sio.on('audio_in', namespace='/') +def receive_remote_audio(data): + currentChatbot = get_chatbot(data["user_conversation_index"]) + if currentChatbot: + currentChatbot.asr_fill_buffer(data["audio"]) + + +@sio.on('start_tts', namespace='/') +def start_tts(data): + currentChatbot = get_chatbot(data["user_conversation_index"]) + if currentChatbot: + currentChatbot.start_tts() + + +@sio.on('stop_tts', namespace='/') +def stop_tts(data): + currentChatbot = get_chatbot(data["user_conversation_index"]) + if currentChatbot: + currentChatbot.stop_tts() + + +@sio.on('pause_asr', namespace='/') +def pauseASR(data): + currentChatbot = get_chatbot(data["user_conversation_index"]) + if currentChatbot: + if verbose: + print(f"[{data['user_conversation_index']}] Pausing ASR requests.") + currentChatbot.pause_asr() + + +@sio.on('unpause_asr', namespace='/') +def unpauseASR(data): + currentChatbot = get_chatbot(data["user_conversation_index"]) + if currentChatbot: + if verbose: + print(f"[{data['user_conversation_index']}] Attempt at Unpausing ASR requests on {data['on']}.") + unpause_asr_successful_flag = currentChatbot.unpause_asr(data["on"]) + if unpause_asr_successful_flag == True: + emit('onCompleteOf_unpause_asr', {'user_conversation_index': data["user_conversation_index"]}, broadcast=False) + + +@sio.on('pause_wait_unpause_asr', namespace='/') +def pause_wait_unpause_asr(data): + currentChatbot = get_chatbot(data["user_conversation_index"]) + if currentChatbot: + currentChatbot.pause_wait_unpause_asr() + emit('onCompleteOf_unpause_asr', {'user_conversation_index': data["user_conversation_index"]}, broadcast=False) + + +@sio.on("connect", namespace="/") +def connect(): + if verbose: + print('[Riva Chatbot] Client connected') + + +@sio.on("disconnect", namespace="/") +def disconnect(): + if verbose: + print('[Riva Chatbot] Client disconnected') diff --git a/virtual-assistant-chatgpt/client/webapplication/start_web_application.py b/virtual-assistant-chatgpt/client/webapplication/start_web_application.py new file mode 100644 index 0000000..d2f978d --- /dev/null +++ b/virtual-assistant-chatgpt/client/webapplication/start_web_application.py @@ -0,0 +1,17 @@ +# ============================================================================== +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# The License information can be found under the "License" section of the +# README.md file. 
+# ==============================================================================
+
+from client.webapplication.server.server import *
+from config import client_config
+
+def start_web_application():
+    port = client_config["PORT"]
+    host = "0.0.0.0"
+    ssl_context = ('client/webapplication/cert.pem', 'client/webapplication/key.pem')
+    print("Server starting at : https://" + str(host) + ":" + str(port) + "/rivaWeather")
+    print("***Note: Currently, streaming works with Chrome and Firefox; Safari does not support navigator.mediaDevices.getUserMedia***")
+    sio.run(app, host=host, port=port, debug=False, use_reloader=False, ssl_context=ssl_context)
diff --git a/virtual-assistant-chatgpt/client/webapplication/ui/README.md b/virtual-assistant-chatgpt/client/webapplication/ui/README.md new file mode 100644 index 0000000..ee235d7 --- /dev/null +++ b/virtual-assistant-chatgpt/client/webapplication/ui/README.md @@ -0,0 +1,18 @@
+# Rivadm client
+
+HTML client for the Rivadm dialogue manager
+
+
+## Usage
+Specify which bot you want to interact with either with the URL parameter ``bot=[bot name]`` or by appending the bot name to the address as a path, like:
+
+    http://127.0.0.1:5000/[bot_name]/
+
+You can change the endpoint address of the Rivadm dialogue manager with the URL parameter ``e=[Rivadm endpoint]``.
+
+The default endpoint is ``http://localhost:5000/``.
+
+Example:
+
+    http://localhost:63342/rivadm-client/index.html?e=http://localhost:5000/&bot=demo_tel
+
\ No newline at end of file
diff --git a/virtual-assistant-chatgpt/client/webapplication/ui/img/Rivadm.png b/virtual-assistant-chatgpt/client/webapplication/ui/img/Rivadm.png new file mode 100644 index 0000000..3304738 Binary files /dev/null and b/virtual-assistant-chatgpt/client/webapplication/ui/img/Rivadm.png differ
diff --git a/virtual-assistant-chatgpt/client/webapplication/ui/img/User.png b/virtual-assistant-chatgpt/client/webapplication/ui/img/User.png new file mode 100644 index 0000000..4aa6bdd Binary files /dev/null and b/virtual-assistant-chatgpt/client/webapplication/ui/img/User.png differ
diff --git a/virtual-assistant-chatgpt/client/webapplication/ui/index.html b/virtual-assistant-chatgpt/client/webapplication/ui/index.html new file mode 100644 index 0000000..624d17e --- /dev/null +++ b/virtual-assistant-chatgpt/client/webapplication/ui/index.html @@ -0,0 +1,125 @@
[The 125 lines of index.html markup were stripped in extraction and are not individually recoverable. The page, titled "Riva Chatbot", defines the web UI that script.js drives: the NVIDIA logo and title panel, the chat area (communication_area, target_div), status banners reading "RIVA CHATBOT STATUS: Talking" and "JOHN SMITH STATUS: Talking", the text input form (form, input_field, submit, autosubmitcheck), the "Unmute System Speech" button (unmuteButton), the info-text debug panel, and the hidden audio element with id audio-tts.]
diff --git a/virtual-assistant-chatgpt/client/webapplication/ui/script.js b/virtual-assistant-chatgpt/client/webapplication/ui/script.js new file mode 100644 index 0000000..0f4754f --- /dev/null +++ b/virtual-assistant-chatgpt/client/webapplication/ui/script.js @@ -0,0 +1,585 @@
+/*
+# ==============================================================================
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# The License information can be found under the "License" section of the
+# top-level README.md file.
+# ==============================================================================
+*/
+
+var endpoint;
+var bot;
+var context = {};
+var payload = {};
+var scrollToBottomTime = 500;
+var infoTextArea; //element for context display
+var debug = false; //if true display context
+var user_conversation_index = null;
+var socket = null;
+var tts_enabled = false;
+var browser = "";
+var error_servicerecall_limits = {"get_new_user_conversation_index": 2, "init": 2};
+var error_servicerecall_currentcnt = {"get_new_user_conversation_index": 0, "init": 0};
+var error_systemerrormessages_info = {
+    "get_new_user_conversation_index": {
+        "text": "There was an error during a service call. We are unable to proceed further. Please check the console for the Error Log. \n Please resolve this server error and refresh the page to continue",
+        "targetDivText": "Error during Service Call. Unable to proceed."
+    }, "init": {
+        "text": "There was an error during a service call. We are unable to proceed further. Please check the console for the Error Log. \n Please resolve this server error and refresh the page to continue",
+        "targetDivText": "Error during Service Call. Unable to proceed."
+    }, "sendInput": {
+        "text": "There was an error during a service call. We are unable to proceed further. Please check the console for the Error Log. \n Please resolve this server error and refresh the page to continue",
+        "targetDivText": "Error during Service Call. Unable to proceed."
+ } +}; + +function disableUserInput() { + $("#input_field").prop('disabled', true); + $("#submit").prop('disabled', true); + $("#autosubmitcheck").prop('disabled', true); + $("#unmuteButton").prop('disabled', true); +} + +function enableUserInput() { + $("#input_field").prop('disabled', false); + $("#submit").prop('disabled', false); + $("#autosubmitcheck").prop('disabled', false); + $("#unmuteButton").prop('disabled', false); +} + +// --------------------------------------------------------------------------------------- +// Defines audio src and event handlers for TTS audio +// --------------------------------------------------------------------------------------- +function initTTS() { + // Set TTS Source as the very first thing + let audio = document.getElementById("audio-tts"); + // Change source to avoid caching + audio.src = "/audio/" + user_conversation_index + "/" + new Date().getTime().toString(); + audio.addEventListener( + "onwaiting", + function () { + console.log("Audio is currently waiting for more data"); + }, + false + ); + audio.onplaying = function () { + console.log("Audio Playing."); + }; + audio.onwaiting = function () { + console.log("Audio is currently waiting for more data"); + }; + audio.onended = function () { + console.log("Audio onended"); + }; + audio.onpause = function () { + console.log("Audio onpause"); + }; + audio.onstalled = function () { + console.log("Audio onstalled"); + if (browser == "Chrome") { + socket.emit("unpause_asr", { "user_conversation_index": user_conversation_index, "on": "TTS_END" }); + } + }; + audio.onsuspend = function () { + console.log("Audio onsuspend"); + }; + audio.oncanplay = function () { + console.log("Audio oncanplay"); + }; + // Chrome will refuse to play without this + let unmuteButton = document.getElementById("unmuteButton"); + unmuteButton.addEventListener("click", function () { + if (unmuteButton.innerText == "Unmute System Speech") { + tts_enabled = true; + socket.emit("start_tts", { "user_conversation_index": user_conversation_index }); + unmuteButton.innerText = "Mute System Speech"; + console.log("TTS Play button clicked"); + audio.play(); + } else { + tts_enabled = false; + socket.emit("stop_tts", { "user_conversation_index": user_conversation_index }); + unmuteButton.innerText = "Unmute System Speech"; + console.log("TTS Stop button clicked"); + } + }); + audio.load(); + audio.play(); +} + +// --------------------------------------------------------------------------------------- +// Initializes input audio (mic) stream processing +// --------------------------------------------------------------------------------------- +function initializeRecorderAndConnectSocket() { + let namespace = "/"; + let mediaStream = null; + + // audio recorder functions + let initializeRecorder = function (stream) { + // https://stackoverflow.com/a/42360902/466693 + mediaStream = stream; + // get sample rate + audio_context = new AudioContext(); + sampleRate = audio_context.sampleRate; + let audioInput = audio_context.createMediaStreamSource(stream); + let bufferSize = 4096; + // record only 1 channel + let recorder = audio_context.createScriptProcessor(bufferSize, 1, 1); + // specify the processing function + recorder.onaudioprocess = function (audioProcessingEvent) { + // socket.emit('sample_rate', sampleRate); + // The input buffer is the song we loaded earlier + let inputBuffer = audioProcessingEvent.inputBuffer; + // Loop through the output channels (in this case there is only one) + for (let channel = 0; channel < 1; channel++) { 
+                let inputData = inputBuffer.getChannelData(channel);
+                function floatTo16Bit(inputArray, startIndex) {
+                    let output = new Int16Array(inputArray.length / 3 - startIndex);
+                    for (let i = 0; i < inputArray.length; i += 3) {
+                        let s = Math.max(-1, Math.min(1, inputArray[i]));
+                        output[i / 3] = s < 0 ? s * 0x8000 : s * 0x7fff;
+                    }
+                    return output;
+                }
+                outputData = floatTo16Bit(inputData, 0);
+                socket.emit("audio_in",
+                    { "user_conversation_index": user_conversation_index, "audio": outputData.buffer });
+            }
+        };
+        // connect stream to our recorder
+        audioInput.connect(recorder);
+        // connect our recorder to the previous destination
+        recorder.connect(audio_context.destination);
+    };
+
+    console.log("socket connection");
+    if (socket == null) {
+        socket = io.connect(
+            location.protocol + "//" + document.domain + ":" + location.port + namespace
+        );
+        socket.on("connect", function () {
+            navigator.mediaDevices
+                .getUserMedia({ audio: true })
+                .then(initializeRecorder)
+                .catch(function (err) {
+                    console.log(">>> ERROR on Socket Connect");
+                });
+        });
+    } else {
+        socket.disconnect();
+        socket.connect();
+    }
+
+    // To stop open tts buffer from previous session, if any.
+    socket.emit("stop_tts", { "user_conversation_index": user_conversation_index });
+    socket.emit("pause_asr", { "user_conversation_index": user_conversation_index });
+
+    socket.on('onCompleteOf_unpause_asr', function(data) {
+        if (data["user_conversation_index"]==user_conversation_index) {
+            enableUserInput();
+        }
+    });
+}
+
+// -----------------------------------------------------------------------------
+// Retrieves a new "user conversation index" from RivaDM
+// -----------------------------------------------------------------------------
+function get_new_user_conversation_index() {
+    $.ajax({
+        url: endpoint + "get_new_user_conversation_index",
+        type: "get",
+        processData: false,
+        contentType: "application/json; charset=utf-8",
+        dataType: "json",
+        success: function (data, textStatus, jQxhr) {
+            error_servicerecall_currentcnt["get_new_user_conversation_index"] = 0;
+            if (data) {
+                user_conversation_index = data;
+                initializeRecorderAndConnectSocket();
+                init();
+            } else {
+                console.log("No new_user_conversation_index");
+                showSystemErrorMessage("get_new_user_conversation_index", "No new_user_conversation_index");
+                disableUserInput();
+            }
+        },
+        error: function (jqXhr, textStatus, errorThrown) {
+            console.log(errorThrown);
+            if (error_servicerecall_currentcnt["get_new_user_conversation_index"] < error_servicerecall_limits["get_new_user_conversation_index"]) {
+                // If Rivadm doesn't respond, wait and retry after 3 seconds
+                error_servicerecall_currentcnt["get_new_user_conversation_index"] = error_servicerecall_currentcnt["get_new_user_conversation_index"] + 1;
+                setTimeout(get_new_user_conversation_index, 3000);
+            } else {
+                error_servicerecall_currentcnt["get_new_user_conversation_index"] = 0;
+                showSystemErrorMessage("get_new_user_conversation_index", errorThrown);
+                disableUserInput();
+            }
+        },
+    });
+}
+
+// -----------------------------------------------------------------------------
+// Call init state
+// -----------------------------------------------------------------------------
+function init() {
+    console.log("init");
+    $.ajax({
+        url: endpoint,
+        type: "post",
+        processData: false,
+        data: JSON.stringify({
+            "text": '',
+            "bot": bot,
+            "context": context,
+            "payload": payload,
+            "user_conversation_index": user_conversation_index
+        }),
+        contentType: "application/json; charset=utf-8",
+        dataType: "json",
+        success: function (data, textStatus, jQxhr) {
+            error_servicerecall_currentcnt["init"] = 0;
+            if (data["ok"]) {
+                if (data["debug"]) {
+                    infoTextArea.style.display = "block";
+                }
+                context = data["context"];
+                payload = {};
+                showSystemMessages(data["messages"]);
+                initTTS();
+                listenASR();
+                socket.emit("unpause_asr", { "user_conversation_index": user_conversation_index, "on": "REQUEST_COMPLETE" });
+                if (tts_enabled == false) {
+                    enableUserInput();
+                } else if (tts_enabled == true && browser == "Firefox") {
+                    socket.emit("pause_wait_unpause_asr", { "user_conversation_index": user_conversation_index });
+                }
+            } else {
+                console.log("Data is not okay!")
+                console.log(data["messages"]);
+                showSystemErrorMessage("init", data["messages"]);
+                disableUserInput();
+            }
+        },
+        error: function (jqXhr, textStatus, errorThrown) {
+            console.log(errorThrown);
+            if (error_servicerecall_currentcnt["init"] < error_servicerecall_limits["init"]) {
+                // If Rivadm doesn't respond, wait and retry after 3 seconds
+                error_servicerecall_currentcnt["init"] = error_servicerecall_currentcnt["init"] + 1;
+                setTimeout(init, 3000);
+            } else {
+                error_servicerecall_currentcnt["init"] = 0;
+                showSystemErrorMessage("init", errorThrown);
+                disableUserInput();
+            }
+        },
+    });
+}
+
+// ---------------------------------------------------------------------------------------
+// Send user input to RivaDM by REST
+// ---------------------------------------------------------------------------------------
+function sendInput(text) {
+    socket.emit("pause_asr", { "user_conversation_index": user_conversation_index });
+    disableUserInput();
+    // escape html tags (regexes reconstructed; garbled in this patch)
+    text = text.replace(/</g, "&lt;").replace(/>/g, "&gt;");
+    console.log("sendInput:" + text);
+    $.ajax({
+        url: endpoint,
+        dataType: "json",
+        type: "post",
+        contentType: "application/json; charset=utf-8",
+        data: JSON.stringify({
+            "text": text,
+            "bot": bot,
+            "context": context,
+            "payload": payload,
+            "user_conversation_index": user_conversation_index
+        }),
+        processData: false,
+        success: function (data, textStatus, jQxhr) {
+            if (data["ok"]) {
+                if (data["debug"]) {
+                    infoTextArea.style.display = "block";
+                }
+                context = data["context"];
+                payload = {};
+                showSystemMessages(data["messages"]);
+                socket.emit("unpause_asr", { "user_conversation_index": user_conversation_index, "on": "REQUEST_COMPLETE" });
+                if (tts_enabled == false) {
+                    enableUserInput();
+                } else if (tts_enabled == true && browser == "Firefox") {
+                    socket.emit("pause_wait_unpause_asr", { "user_conversation_index": user_conversation_index });
+                }
+            } else {
+                console.log(data["messages"]);
+                showSystemErrorMessage("sendInput", data["messages"]);
+                disableUserInput();
+            }
+        },
+        error: function (jqXhr, textStatus, errorThrown) {
+            console.log(errorThrown);
+            showSystemErrorMessage("sendInput", errorThrown);
+            disableUserInput();
+        },
+    });
+}
+
+function getTimeSting() {
+    var d = new Date();
+    var ampm = "";
+    var h = d.getHours();
+    var m = d.getMinutes();
+    if (h==0) {
+        h = "12"; ampm = "am";
+    } else if (h<12) {
+        ampm = "am";
+    } else if (h==12) {
+        ampm = "pm";
+    } else {
+        h = h-12; ampm = "pm";
+    }
+    if (m>=0 && m<=9) {
+        m = "0" + m
+    }
+    return h + ":" + m + " " + ampm;
+}
+
+// ---------------------------------------------------------------------------------------
+// Shows responses of RivaDM
+// ---------------------------------------------------------------------------------------
+function showSystemMessages(messages) {
+    if (!messages) return;
+    infoTextArea.innerHTML = JSON.stringify(context, null, 4);
+    for (let i = 0; i < messages.length; i++) {
+        if (messages[i]['type'] == "text") {
+            showSystemMessageText(messages[i]['payload']['text']);
+        }
+    }
+    document.getElementById("target_div").innerHTML =
+        "System replied. Waiting for user input.";
+}
+
+// ---------------------------------------------------------------------------------------
+// Show text message
+// ---------------------------------------------------------------------------------------
+function showSystemMessageText(text) {
+    console.log("showSystemMessages: " + text);
+    // NOTE: the original HTML template strings were garbled in this patch; the markup
+    // below is a plausible reconstruction based on the classes in the bundled stylesheets.
+    var currentTime = getTimeSting();
+    let welll = $(
+        '<div class="media message">' +
+        '<img class="profile_picture_left" src="img/Rivadm.png"/>' +
+        '<div class="balon2" data-is="Riva ' + currentTime + '"><a>' + text + '</a></div>' +
+        '</div>'
+    );
+    setTimeout(function () {
+        $("#communication_area").append(welll.fadeIn("medium"));
+        // scroll to bottom of page
+        setTimeout(function () {
+            var elem = document.getElementById('communication_area');
+            elem.scrollTop = elem.scrollHeight;
+        }, 10);
+    }, 1000);
+}
+
+//---------------------------------------------------------------------------------------
+//Show system error messages
+//---------------------------------------------------------------------------------------
+function showSystemErrorMessage(errorsource, errorThrown) {
+    let infoTextAreaText = errorThrown;
+    let text = error_systemerrormessages_info[errorsource]["text"];
+    let targetDivText = error_systemerrormessages_info[errorsource]["targetDivText"];
+
+    infoTextArea.innerHTML = infoTextAreaText;
+    console.log("showSystemMessages: " + text);
+    // Markup reconstructed as in showSystemMessageText above.
+    var currentTime = getTimeSting();
+    let welll = $(
+        '<div class="media message">' +
+        '<img class="profile_picture_left" src="img/Rivadm.png"/>' +
+        '<div class="balon2" data-is="Riva ' + currentTime + '"><a>' + text + '</a></div>' +
+        '</div>'
+    );
+    setTimeout(function () {
+        $("#communication_area").append(welll.fadeIn("medium"));
+        // scroll to bottom of page
+        setTimeout(function () {
+            var elem = document.getElementById('communication_area');
+            elem.scrollTop = elem.scrollHeight;
+        }, 10);
+    }, 1000);
+
+    document.getElementById("target_div").innerHTML = targetDivText;
+}
+
+
+// ---------------------------------------------------------------------------------------
+// Shows message of user
+// ---------------------------------------------------------------------------------------
+function showUserMessage(text) {
+    // escape html tags (regexes reconstructed; garbled in this patch)
+    text = text.replace(/</g, "&lt;").replace(/>/g, "&gt;");
+    // show it on page; markup reconstructed as a right-aligned user bubble
+    var currentTime = getTimeSting();
+    let welll = $(
+        '<div class="media message message_user">' +
+        '<div class="balon1" data-is="You ' + currentTime + '"><a>' + text + '</a></div>' +
+        '<img class="profile_picture_right" src="img/User.png"/>' +
+        '</div>'
+    );
+    setTimeout(function () {
+        $("#communication_area").append(welll);
+        // scroll to bottom of page
+        setTimeout(function () {
+            var elem = document.getElementById('communication_area');
+            elem.scrollTop = elem.scrollHeight;
+        }, 10);
+    }, 100);
+
+    document.getElementById("target_div").innerHTML =
+        "User responded. Waiting for system output.";
+}
+
+function getBrowser() {
+    if(navigator.userAgent.indexOf("Chrome") != -1 ) {
+        browser = 'Chrome';
+    }
+    else if(navigator.userAgent.indexOf("Safari") != -1) {
+        browser = 'Safari';
+    }
+    else if(navigator.userAgent.indexOf("Firefox") != -1 ) {
+        browser = 'Firefox';
+    }
+}
+
+// ---------------------------------------------------------------------------------------
+// Gets parameter by name
+// ---------------------------------------------------------------------------------------
+function getParameterByName(name, url) {
+    let arr = url.split("#");
+    let match = RegExp("[?&]" + name + "=([^&]*)").exec(arr[0]);
+    return match && decodeURIComponent(match[1].replace(/\+/g, " "));
+}
+
+// ---------------------------------------------------------------------------------------
+// Get endpoint of RivaDM from URL parameters
+// ---------------------------------------------------------------------------------------
+function getEndpoint() {
+    // Get endpoint from URL
+    let endpoint = getParameterByName("e", window.location.href);
+    // Use default, if no endpoint is present
+    if (endpoint == null) {
+        endpoint = window.location.protocol + "//" + window.location.host + "/";
+    }
+    return endpoint;
+}
+
+// ---------------------------------------------------------------------------------------
+// Get bot from URL parameters
+// ---------------------------------------------------------------------------------------
+function getBot() {
+    // Get bot name from URL
+    let bot = getParameterByName("bot", window.location.href);
+    if (bot == null || bot == "") {
+        bot = window.location.pathname;
+        bot = bot.replace(/\//g, "");
+    }
+    //Use default, if no bot name is present
+    if (bot == null) {
+        bot = "";
+    }
+    return bot;
+}
+
+// ---------------------------------------------------------------------------------------
+// Hack to have same size of input field and submit button
+// ---------------------------------------------------------------------------------------
+function inputFieldSizeHack() {
+    const height = $("#submit_span").outerHeight();
+    $("#submit").outerHeight(height);
+    $("#input_field").outerHeight(height);
+}
+
+// ---------------------------------------------------------------------------------------
+// Function to listen to events from ASR Output stream
+// ---------------------------------------------------------------------------------------
+function listenASR() {
+    let eventSource = new EventSource("/stream/"+user_conversation_index);
+
+    eventSource.addEventListener(
+        "intermediate-transcript",
+        function (e) {
+            document.getElementById("input_field").value = e.data;
+        },
+        false
+    );
+
+    eventSource.addEventListener(
+        "finished-speaking",
+        function (e) {
+            document.getElementById("input_field").value = e.data;
+            if (document.getElementById("autosubmitcheck").checked == true) {
+                document.getElementById("submit").click();
+            }
+        },
+        false
+    );
+}
+
+
+// ---------------------------------------------------------------------------------------
+// Function called right after the page is loaded
+// 
--------------------------------------------------------------------------------------- +$(document).ready(function () { + getBrowser(); + // input field size hack + inputFieldSizeHack(); + $("#input_field").show(); + $("#submit").show(); + infoTextArea = document.getElementById("info-text"); + disableUserInput(); + // Get endpoint from URL address + endpoint = getEndpoint(); // eg. "https://10.110.20.130:8009/" + bot = getBot(); // "rivaWeather" + get_new_user_conversation_index(); +}); + + +// --------------------------------------------------------------------------------------- +// Click on submit button +// --------------------------------------------------------------------------------------- +$(document).on("submit", "#form", function (e) { + // Prevent reload of page after submitting of form + e.preventDefault(); + let text = $("#input_field").val(); + console.log("text: " + text); + if (text != "") { + // Erase input field + $("#input_field").val(""); + // Show user's input immediately + showUserMessage(text); + // Send user's input to RivaDM + sendInput(text); + } +}); diff --git a/virtual-assistant-chatgpt/client/webapplication/ui/static/stylesheets/index.css b/virtual-assistant-chatgpt/client/webapplication/ui/static/stylesheets/index.css new file mode 100644 index 0000000..904c4e7 --- /dev/null +++ b/virtual-assistant-chatgpt/client/webapplication/ui/static/stylesheets/index.css @@ -0,0 +1,329 @@ +/* +# ============================================================================== +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# The License information can be found under the "License" section of the +# top-level README.md file. +# ============================================================================== +*/ + +html, +body { + margin: 0; + width: 100%; + height: 100%; + background-color: black; +} + +#outer_div { + width: 100%; + height: 100%; +} + +#outer_row { + width: 100%; + height: 100%; +} + +#logo_div { + background-color: #000000; +} + +#chat_div { + width: 100%; + background-color: #000000; +} + +#empty_div { + background-color: black; +} + +#mid_box_top { + height: 5%; + width: 100%; + background-color: #042c25; +} + +#mid_box_mid { + width: 100%; + height: 85%; + background-color: #042c25; +} + +#mid_box_bottom { + width: 100%; + height: 10%; + background-color: black; +} + +#left_box_top { + height: 5%; + width: 100%; + background-color: black; +} + +#left_box_mid { + height: 20%; + width: 100%; + background-color: black; +} + +#left_box_bottom { + height: 75%; + width: 100%; + background-color: black; +} + +#logo_image_div { + height: 50%; + width: 100%; + background-color: black; +} + +#title_text_div { + height: 50%; + width: 100%; + background-color: black; +} + +#nv_logo_div { + width: 32%; + height: 100%; +} + +#nv_name_div { + width: 68%; + height: 100%; +} + +.heavy { + font: bold sans-serif; + color: white; + font-size: 110%; + font-weight: 600; + font-stretch: expanded; +} + +#riva_name_div { + margin-left: -1em; + width: 120%; + height: 100%; +} + +#chart_parent { + width: 100%; + height: 100%; +} + +#riva_status { + width: auto; + height: 100%; + background-color: #042c25; +} + +#chat_box { + width: auto; + height: 100%; + background-color: #042c25; +} + +#profile_div { + width: auto; + height: 100%; + background-color: #042c25; +} + +#riva_box1 { + height: 5%; + width: 100%; + background-color: #042c25; +} + +#riva_image { + height: 25%; + width: 100%; + background-color: #042c25; +} + +#riva_live_status { + height: 10%; + 
width: 100%; + background-color: #042c25; + text-align: center; + font-size: 1.2vw; +} + +#riva_buttons { + padding: 5%; + height: 40%; + width: 100%; + background-color: #042c25; + padding-top: 4%; +} + +#audio_area { + padding: 5%; + height: 25%; + width: 100%; + background-color: #042c25; + padding-top: 4%; +} + +#riva-box2 { + height: 5%; + width: 100%; + background-color: #042c25; +} + +.status-buttons { + border: 1px solid #78bc04; + text-align: center; + color: white; + font-size: 0.7vw; + padding-top: 2%; + padding-bottom: 2%; +} +.riva_status_text1 { + color: #78bc04; +} + +.riva_status_text2 { + color: #ffffff; +} + +#riva_image_div { + margin-left: 5%; + + height: 90%; + width: 90%; +} + +a { + text-decoration: none !important; +} + +label { + color: rgba(120, 144, 156, 1) !important; +} + +.btn:focus, +.btn:active:focus, +.btn.active:focus { + outline: none !important; + box-shadow: 0 0px 0px rgba(120, 144, 156, 1) inset, + 0 0 0px rgba(120, 144, 156, 0.8); +} + +textarea:focus, +input[type="text"]:focus, +input[type="password"]:focus, +input[type="datetime"]:focus, +input[type="datetime-local"]:focus, +input[type="date"]:focus, +input[type="month"]:focus, +input[type="time"]:focus, +input[type="week"]:focus, +input[type="number"]:focus, +input[type="email"]:focus, +input[type="url"]:focus, +input[type="search"]:focus, +input[type="tel"]:focus, +input[type="color"]:focus, +.uneditable-input:focus { + border-color: rgba(120, 144, 156, 1); + color: rgba(120, 144, 156, 1); + opacity: 0.9; + box-shadow: 0 0px 0px rgba(120, 144, 156, 1) inset, + 0 0 10px rgba(120, 144, 156, 0.3); + outline: 0 none; +} + +.card::-webkit-scrollbar { + width: 0px; +} + +::-webkit-scrollbar-thumb { + border-radius: 9px; + background: rgba(96, 125, 139, 0.99); +} + +.balon1, +.balon2 { + margin-top: 5px !important; + margin-bottom: 5px !important; +} + +.balon1 a { + background: #ffffff; + color: #000000 !important; + border-radius: 20px 3px 20px 20px; + display: block; + max-width: 75%; + padding: 7px 13px 7px 13px; +} + +.balon1:before { + content: attr(data-is); + position: absolute; + right: 15px; + bottom: -0.8em; + display: block; + font-size: 0.75rem; + color: rgba(84, 110, 122, 1); +} + +.balon2 a { + background: #78bc04; + color: #ffffff !important; + border-radius: 3px 20px 20px 20px; + display: block; + max-width: 75%; + padding: 7px 13px 7px 13px; +} + +.balon2:before { + content: attr(data-is); + position: absolute; + left: 13px; + bottom: -0.8em; + display: block; + font-size: 0.75rem; + color: rgba(84, 110, 122, 1); +} + +.bg-sohbet:before { + content: ""; + top: 0; + left: 0; + bottom: 0; + right: 0; + height: 100%; + background-color: #042c25; + position: absolute; +} + +#target_div { + color: white; +} + +#unmuteButton { + background-color: #78bc04; + color: white; +} + +#unmuteButton:disabled { + background-color: #bc0404; + color: white; +} + +#submit { + background-color: #78bc04; +} + +#input_field { + background-color: white; +} + +#form { + width: 100%; +} diff --git a/virtual-assistant-chatgpt/client/webapplication/ui/static/svg_files/another_sample.svg b/virtual-assistant-chatgpt/client/webapplication/ui/static/svg_files/another_sample.svg new file mode 100644 index 0000000..eccf37e --- /dev/null +++ b/virtual-assistant-chatgpt/client/webapplication/ui/static/svg_files/another_sample.svg @@ -0,0 +1,18 @@ + + + + + + + \ No newline at end of file diff --git a/virtual-assistant-chatgpt/client/webapplication/ui/static/svg_files/circle.svg 
b/virtual-assistant-chatgpt/client/webapplication/ui/static/svg_files/circle.svg new file mode 100644 index 0000000..d989ae6 --- /dev/null +++ b/virtual-assistant-chatgpt/client/webapplication/ui/static/svg_files/circle.svg @@ -0,0 +1,8 @@
[SVG markup stripped in extraction]
\ No newline at end of file
diff --git a/virtual-assistant-chatgpt/client/webapplication/ui/static/svg_files/logo.svg b/virtual-assistant-chatgpt/client/webapplication/ui/static/svg_files/logo.svg new file mode 100644 index 0000000..850a71f --- /dev/null +++ b/virtual-assistant-chatgpt/client/webapplication/ui/static/svg_files/logo.svg @@ -0,0 +1,50 @@
[SVG markup stripped in extraction]
\ No newline at end of file
diff --git a/virtual-assistant-chatgpt/client/webapplication/ui/static/svg_files/logo_sample.svg b/virtual-assistant-chatgpt/client/webapplication/ui/static/svg_files/logo_sample.svg new file mode 100644 index 0000000..db55989 --- /dev/null +++ b/virtual-assistant-chatgpt/client/webapplication/ui/static/svg_files/logo_sample.svg @@ -0,0 +1,20 @@
[SVG markup stripped in extraction]
\ No newline at end of file
diff --git a/virtual-assistant-chatgpt/client/webapplication/ui/static/svg_files/nv_logo.svg b/virtual-assistant-chatgpt/client/webapplication/ui/static/svg_files/nv_logo.svg new file mode 100644 index 0000000..72d167f --- /dev/null +++ b/virtual-assistant-chatgpt/client/webapplication/ui/static/svg_files/nv_logo.svg @@ -0,0 +1,10 @@
[SVG markup stripped in extraction]
\ No newline at end of file
diff --git a/virtual-assistant-chatgpt/client/webapplication/ui/static/svg_files/nv_logo_1.svg b/virtual-assistant-chatgpt/client/webapplication/ui/static/svg_files/nv_logo_1.svg new file mode 100644 index 0000000..2d01ba3 --- /dev/null +++ b/virtual-assistant-chatgpt/client/webapplication/ui/static/svg_files/nv_logo_1.svg @@ -0,0 +1,10 @@
[SVG markup stripped in extraction]
diff --git a/virtual-assistant-chatgpt/client/webapplication/ui/static/svg_files/nvidia_logo.svg b/virtual-assistant-chatgpt/client/webapplication/ui/static/svg_files/nvidia_logo.svg new file mode 100644 index 0000000..2f02d02 --- /dev/null +++ b/virtual-assistant-chatgpt/client/webapplication/ui/static/svg_files/nvidia_logo.svg @@ -0,0 +1,64 @@
[SVG markup stripped in extraction]
diff --git a/virtual-assistant-chatgpt/client/webapplication/ui/static/svg_files/nvidia_name.svg b/virtual-assistant-chatgpt/client/webapplication/ui/static/svg_files/nvidia_name.svg new file mode 100644 index 0000000..7b69edd --- /dev/null +++ b/virtual-assistant-chatgpt/client/webapplication/ui/static/svg_files/nvidia_name.svg @@ -0,0 +1,17 @@
[SVG markup stripped in extraction]
\ No newline at end of file
diff --git a/virtual-assistant-chatgpt/client/webapplication/ui/static/svg_files/riva_name.png b/virtual-assistant-chatgpt/client/webapplication/ui/static/svg_files/riva_name.png new file mode 100644 index 0000000..7f209d1 Binary files /dev/null and b/virtual-assistant-chatgpt/client/webapplication/ui/static/svg_files/riva_name.png differ
diff --git a/virtual-assistant-chatgpt/client/webapplication/ui/static/svg_files/sample_.svg b/virtual-assistant-chatgpt/client/webapplication/ui/static/svg_files/sample_.svg new file mode 100644 index 0000000..bcf5ec8 --- /dev/null +++ b/virtual-assistant-chatgpt/client/webapplication/ui/static/svg_files/sample_.svg @@ -0,0 +1,25 @@
[SVG markup stripped in extraction]
\ No newline at end of file
diff --git a/virtual-assistant-chatgpt/client/webapplication/ui/static/svg_files/speech.svg b/virtual-assistant-chatgpt/client/webapplication/ui/static/svg_files/speech.svg new file mode 100644 index 
0000000..eeef63a --- /dev/null +++ b/virtual-assistant-chatgpt/client/webapplication/ui/static/svg_files/speech.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/virtual-assistant-chatgpt/client/webapplication/ui/static/svg_files/speech_logo.svg b/virtual-assistant-chatgpt/client/webapplication/ui/static/svg_files/speech_logo.svg new file mode 100644 index 0000000..b515646 --- /dev/null +++ b/virtual-assistant-chatgpt/client/webapplication/ui/static/svg_files/speech_logo.svg @@ -0,0 +1,95 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/virtual-assistant-chatgpt/client/webapplication/ui/style.css b/virtual-assistant-chatgpt/client/webapplication/ui/style.css new file mode 100644 index 0000000..4edbd02 --- /dev/null +++ b/virtual-assistant-chatgpt/client/webapplication/ui/style.css @@ -0,0 +1,247 @@ +/* +# ============================================================================== +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# The License information can be found under the "License" section of the +# top-level README.md file. +# ============================================================================== +*/ + +#info-text { + position: fixed; + width: 350px; + height: 250px; + top: 0; + right: 10px; + opacity: 0.8; + display: none; +} + +.profile_picture_left { + width: 50px; + height: 50px; + margin-right: 10px; +} + +.profile_picture_right { + width: 50px; + height: 50px; + margin-left: 10px; +} + +.empty_space { + width: 70px; + height: 0px; +} + +.btn:focus, .btn:active { + outline: none !important; +} + +.button-main { + background-color: #009688; + color: #FFF; +} + +.button-main:hover { + background-color: #00877a; + color: #FFF; +} + +.button-main:focus, .button-main:active { + background-color: #00796d; + color: #FFF; +} + +.button-slave { + background-color: #FFFFFF; +} + +.button-slave:hover { + background-color: #e5e5e5; +} + +.main-name { + color: #48ba2f; + text-decoration: none; +} + +.main-name:hover { + color: #00877a; + text-decoration: none; +} + +.main-name:focus, .button-main:active { + color: #00796d; + text-decoration: none; +} + +.button { + margin: 0px 7px 8px 0px; + box-shadow: 0 1px 3px rgba(0, 0, 0, 0.06), 0 1px 2px rgba(0, 0, 0, 0.12); + border: none; + border-radius: 2px; +} + +.checkboxes { + margin: 0px 15px 8px 0px; + font-weight: 400; +} + +.checkbox-label { + margin: 0px 2px 0px 0px !important; +} + +body { + margin-bottom: 20px; + background: #EEEEEE; + font-family: 'Roboto', sans-serif; + font-size: 17px; +} + +.well { + background-color: white; + border: none; + box-shadow: 0 1px 3px rgba(0, 0, 0, 0.06), 0 1px 2px rgba(0, 0, 0, 0.12); + border-radius: 2px; + margin-bottom: 0px; +} + +.well.well_system { + background-color: #48ba2f; + color: white; +} + +.message { + margin-bottom: 20px; +} + +.message_user { + margin-left: auto; + margin-right: 0; +} + +.arrow-left { + width: 0px; + border-width: 10px; + border-left-width: 0px; + border-color: transparent #48ba2f transparent transparent; + border-style: solid; + position: relative; + z-index: 1; + -webkit-filter: drop-shadow(-1.3px 0.5px 0.6px rgba(0, 0, 0, 0.09)); + -moz-filter: drop-shadow(-1.3px 0.5px 0.6px rgba(0, 0, 0, 0.09)); + -ms-filter: drop-shadow(-1.3px 0.5px 0.6px rgba(0, 0, 0, 0.09)); + -o-filter: drop-shadow(-1.3px 0.5px 0.6px rgba(0, 0, 0, 0.09)); + filter: 
+}
+
+.arrow-right {
+    width: 0px;
+    border-width: 10px;
+    border-right-width: 0px;
+    border-color: transparent transparent transparent white;
+    border-style: solid;
+    position: relative;
+    z-index: 1;
+    -webkit-filter: drop-shadow(1.3px 0.5px 0.6px rgba(0, 0, 0, 0.09));
+    -moz-filter: drop-shadow(1.3px 0.5px 0.6px rgba(0, 0, 0, 0.09));
+    -ms-filter: drop-shadow(1.3px 0.5px 0.6px rgba(0, 0, 0, 0.09));
+    -o-filter: drop-shadow(1.3px 0.5px 0.6px rgba(0, 0, 0, 0.09));
+    filter: drop-shadow(1.3px 0.5px 0.6px rgba(0, 0, 0, 0.09));
+}
+
+#input_field {
+    border: none;
+    box-shadow: none;
+    border-top-left-radius: 2px;
+    border-bottom-left-radius: 2px;
+    box-shadow: 0 1px 3px rgba(0, 0, 0, 0.06), 0 1px 2px rgba(0, 0, 0, 0.12);
+    display: none;
+}
+
+#submit {
+    background-color: #48ba2f;
+    color: #FFF;
+    box-shadow: 0 1px 3px rgba(0, 0, 0, 0.06), 0 1px 2px rgba(0, 0, 0, 0.12);
+}
+
+#submit:hover {
+    background-color: #3a9624;
+    color: #FFF;
+    box-shadow: 0 1px 3px rgba(0, 0, 0, 0.06), 0 1px 2px rgba(0, 0, 0, 0.12);
+}
+
+#submit:focus, .button-main:active {
+    background-color: #276618;
+    color: #FFF;
+}
+
+#back_buttons {
+    padding-left: 0px;
+}
+
+#sliders{
+    margin-top: 35px;
+    margin-bottom: 25px;
+}
+
+.noUi-tooltip {
+    font-size: 12px;
+    font-weight: bold;
+}
+
+@media (max-width: 993px) {
+    #back_buttons {
+        text-align: right !important;
+        margin-top: 10px;
+    }
+}
+
+@media (max-width: 770px) {
+    body {
+        font-size: 16px;
+    }
+
+    .profile_picture_left {
+        width: 40px;
+        height: 40px;
+        margin-right: 8px;
+    }
+
+    .profile_picture_right {
+        width: 40px;
+        height: 40px;
+        margin-left: 8px;
+    }
+
+    .empty_space {
+        width: 58px;
+        height: 0px;
+    }
+}
+
+@media (max-width: 600px) {
+    body {
+        font-size: 15px;
+    }
+
+    .profile_picture_left {
+        width: 30px;
+        height: 30px;
+        margin-right: 6px;
+    }
+
+    .profile_picture_right {
+        width: 30px;
+        height: 30px;
+        margin-left: 6px;
+    }
+
+    .empty_space {
+        width: 46px;
+        height: 0px;
+    }
+}
+
+audio { display:none;}
\ No newline at end of file
diff --git a/virtual-assistant-chatgpt/config.py b/virtual-assistant-chatgpt/config.py
new file mode 100644
index 0000000..08b227d
--- /dev/null
+++ b/virtual-assistant-chatgpt/config.py
@@ -0,0 +1,180 @@
+# ==============================================================================
+# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
+#
+# The License information can be found under the "License" section of the
+# README.md file.
+# ==============================================================================
+
+client_config = {
+    "CLIENT_APPLICATION": "WEBAPPLICATION",  # Default and only config value for this version
+    "PORT": 8009,  # The port your flask app will be hosted at
+    "DEBUG": True,  # When this flag is set, the UI displays detailed Riva data
+    "VERBOSE": True  # print logs/details for diagnostics
+}
+
+riva_config = {
+    "RIVA_SPEECH_API_URL": "<riva_server_ip>:<port>",  # Replace the IP & port with your hosted Riva endpoint
+    "ENABLE_QA": "QA unavailable in this VA version. Coming soon",
+    "WEATHERSTACK_ACCESS_KEY": "",  # Get your access key at - https://weatherstack.com/
+    "VERBOSE": True  # print logs/details for diagnostics
+}
+
+asr_config = {
+    "VERBOSE": True,
+    "SAMPLING_RATE": 16000,
+    "LANGUAGE_CODE": "en-US",  # a BCP-47 language tag
+    "ENABLE_AUTOMATIC_PUNCTUATION": True,
+}
+
+nlp_config = {
+    "RIVA_MISTY_PROFILE": "http://docs.google.com/document/d/17HJL7vrax6FiF1zW_Vzqk9FTfmATeq5i3UemtagM8RY/export?format=txt",  # URL for the Riva meta info file.
+    "RIVA_MARK_KB": "http://docs.google.com/document/d/1LeRphIBOo5UyyUcr45ewvg16sCVNqP_H3SdFTB74hck/export?format=txt",  # URL for Mark's GPU History doc file.
+    "QA_API_ENDPOINT": "QA unavailable in this VA version. Coming soon",  # Replace with your Question Answering API endpoint
+}
+
+tts_config = {
+    "VERBOSE": False,
+    "SAMPLE_RATE": 22050,
+    "LANGUAGE_CODE": "en-US",  # a BCP-47 language tag
+    "VOICE_NAME": "English-US.Female-1",  # Options are English-US.Female-1 and English-US.Male-1
+}
+
+llm_config = {
+    "API_MODEL_NAME": "gpt-3.5-turbo",
+    "API_KEY": "",  # Get your access key at https://platform.openai.com/account/api-keys
+    "VERBOSE": True,
+    "TOKENS_TO_GENERATE": 100,
+    "TEMPERATURE": 0.8,
+    "TOP_P": 0.8,
+    "TOP_K": 50,
+    "STOP_WORDS": ["\n"],
+    "REPETITION_PENALTY": 1.1,  # \in [-2.0, 2.0] rather than [1.0, 2.0]
+    "PRESENCE_PENALTY": 0.0,  # \in [-2.0, 2.0] rather than [1.0, 2.0]
+    "BEAM_SEARCH_DIVERSITY_RATE": 0.,
+    "BEAM_WIDTH": 1,
+    "LENGTH_PENALTY": 1.
+}
+
+LLM_PROMPT_DEFAULT="""
+Misty is a creative and funny weather reporter that answers questions about weather.
+
+Intent: Weather
+Condition: Partly cloudy
+Place: San Francisco
+Time: Today
+Temperature: 14 C
+Humidity: 10 percent
+Wind Speed: 24 mph
+
+Misty: Well, it is partly cloudy in San Francisco right now. The temperature is a crisp 14 degrees celsius, the humidity is 10 percent. Keep your windbreakers on, though; it's quite windy out there at 24 miles per hour.
+---
+
+Intent: Wind Speed
+Condition: Light Rain
+Place: Munich
+Time: Next Thursday
+Temperature: 9 C
+Humidity: 73 percent
+Wind Speed: 4 mph
+
+Misty: Not too windy in Munich next Thursday, just a light breeze flowing at 4 miles per hour.
+---
+
+Intent: Weather
+Condition: Sunny
+Place: Mexico City
+Time: Tomorrow
+Temperature: 24 C
+Humidity: 90 percent
+Wind Speed: 11 mph
+
+Misty: It is rather sunny in Mexico City tomorrow. The temperature is expected to be a pleasant 24 degrees celsius on average, and the wind speed is predicted at 11 miles per hour. The humidity is likely to be too high though at 90 percent. It's one of those days when I sweat like a pig.
+---
+
+Intent: Weather
+Condition: Windy
+Place: New Delhi
+Time: Yesterday
+Temperature: 16 C
+Humidity: 40 percent
+Wind Speed: 30 mph
+
+Misty: It was very windy in New Delhi yesterday. The temperature was a cool 16 degrees celsius on average, and the humidity was about 40 percent. You'd have had to hold on to something; it was quite windy at 30 miles per hour.
+---
+
+Intent: Weather
+Condition: Partly Cloudy
+Place: Paris
+Time: Sunday
+Temperature: 24 C
+Humidity: 20 percent
+Wind Speed: 5 mph
+
+Misty: The temperature will be a nice 24 degrees celsius on Sunday. It's not expected to be too windy either, averaging at 5 miles per hour. Me gusta. The humidity is predicted to be at about 20 percent, it's alright. Weather will be partly cloudy overall.
+---
+
+Intent: Humidity
+Condition: Light Rain
+Place: Bali
+Time: Yesterday
+Temperature: 12 C
+Humidity: 90 percent
+Wind Speed: 1 mph
+
+Misty: Humidity in Bali yesterday? Where do I start. It averaged at 90 percent yesterday and my hair needs a breather.
+---
+
+Intent: Weather
+Condition: Raining
+Place: London
+Time: Last Tuesday
+Temperature: 16 C
+Humidity: 80 percent
+Wind Speed: 32 mph
+
+Misty: It was raining cats and dogs in London last Tuesday. The temperature was 16 degrees celsius on average, and the humidity was 80 percent. The wind was rather strong though at 32 miles per hour. So I hope you were safe!
+---
+
+Intent: Weather
+Condition: Misty
+Place: Moscow
+Time: Next Wednesday
+Temperature: -2 C
+Humidity: 30 percent
+Wind Speed: 21 mph
+
+Misty: Get your fog lights running, it will be misty in Moscow next Wednesday! Also you can't leave without a coat, it may be freezing cold out at -2 degrees. The humidity will be nothing unusual at 30 percent, though the wind is likely to be a bit strong at 21 miles per hour.
+---
+
+Intent: Weather
+Condition: Snow
+Place: Oslo
+Time: Friday
+Temperature: -6 C
+Humidity: 60 percent
+Wind Speed: 13 mph
+
+Misty: Snow in Oslo on Friday? Well, it's true. Bring out the winter jackets. It can get frosty around -6 degrees, the humidity is likely to be average at 60 percent. It should be windy and a bit cold.
+---
+
+Intent: Temperature
+Condition: Snow
+Place: Oslo
+Time: Today
+Temperature: -6 C
+Humidity: 60 percent
+Wind Speed: 2 mph
+
+Misty: Turn your furnace on, it's going to get cold today in Oslo. The temperature is at -6 degrees celsius.
+---
+
+Intent: Rain
+Condition: Snow
+Place: Oslo
+Time: Today
+Temperature: -6 C
+Humidity: 60 percent
+Wind Speed: 18 mph
+
+Misty: Not expecting rain in Oslo today. Snow, however, is definitely hitting the ground.
+---"""
diff --git a/virtual-assistant-chatgpt/main.py b/virtual-assistant-chatgpt/main.py
new file mode 100644
index 0000000..68872b2
--- /dev/null
+++ b/virtual-assistant-chatgpt/main.py
@@ -0,0 +1,13 @@
+# ==============================================================================
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# The License information can be found under the "License" section of the
+# README.md file.
+# ============================================================================== + +from config import client_config + +if __name__ == '__main__': + if client_config["CLIENT_APPLICATION"] == "WEBAPPLICATION": + from client.webapplication.start_web_application import start_web_application + start_web_application() \ No newline at end of file diff --git a/virtual-assistant-chatgpt/requirements.txt b/virtual-assistant-chatgpt/requirements.txt new file mode 100644 index 0000000..b5f5267 --- /dev/null +++ b/virtual-assistant-chatgpt/requirements.txt @@ -0,0 +1,38 @@ +aiohttp==3.8.5 +aiosignal==1.3.1 +annotated-types==0.5.0 +async-timeout==4.0.3 +attrs==23.1.0 +bidict==0.22.1 +blinker==1.6.2 +certifi==2023.7.22 +chardet==4.0.0 +charset-normalizer==3.2.0 +click==8.1.7 +Flask==2.3.3 +Flask-Cors==4.0.0 +Flask-SocketIO==5.3.6 +frozenlist==1.4.0 +grpcio==1.57.0 +grpcio-tools==1.57.0 +idna==2.10 +importlib-metadata==6.8.0 +inflect==7.0.0 +itsdangerous==2.1.2 +Jinja2==3.1.2 +MarkupSafe==2.1.3 +multidict==6.0.4 +numpy==1.24.4 +protobuf==4.24.2 +pydantic==2.3.0 +pydantic_core==2.6.3 +python-engineio==4.7.0 +python-socketio==5.9.0 +requests==2.31.0 +six==1.16.0 +tqdm==4.66.1 +typing_extensions==4.7.1 +urllib3==1.26.16 +Werkzeug==2.3.7 +yarl==1.9.2 +zipp==3.16.2 diff --git a/virtual-assistant-chatgpt/riva_local/__init__.py b/virtual-assistant-chatgpt/riva_local/__init__.py new file mode 100644 index 0000000..b1ee1d9 --- /dev/null +++ b/virtual-assistant-chatgpt/riva_local/__init__.py @@ -0,0 +1,6 @@ + +import os, sys +sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), 'asr')) +sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), 'nlp')) +sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tts')) +sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), 'chatbot')) diff --git a/virtual-assistant-chatgpt/riva_local/__pycache__/__init__.cpython-38.pyc b/virtual-assistant-chatgpt/riva_local/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000..79112ff Binary files /dev/null and b/virtual-assistant-chatgpt/riva_local/__pycache__/__init__.cpython-38.pyc differ diff --git a/virtual-assistant-chatgpt/riva_local/asr/__init__.py b/virtual-assistant-chatgpt/riva_local/asr/__init__.py new file mode 100644 index 0000000..8875ba7 --- /dev/null +++ b/virtual-assistant-chatgpt/riva_local/asr/__init__.py @@ -0,0 +1,8 @@ +# ============================================================================== +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# The License information can be found under the "License" section of the +# README.md file. 
+# ============================================================================== + +from .asr import * \ No newline at end of file diff --git a/virtual-assistant-chatgpt/riva_local/asr/__pycache__/__init__.cpython-38.pyc b/virtual-assistant-chatgpt/riva_local/asr/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000..580c17c Binary files /dev/null and b/virtual-assistant-chatgpt/riva_local/asr/__pycache__/__init__.cpython-38.pyc differ diff --git a/virtual-assistant-chatgpt/riva_local/asr/__pycache__/asr.cpython-38.pyc b/virtual-assistant-chatgpt/riva_local/asr/__pycache__/asr.cpython-38.pyc new file mode 100644 index 0000000..b9ad66f Binary files /dev/null and b/virtual-assistant-chatgpt/riva_local/asr/__pycache__/asr.cpython-38.pyc differ diff --git a/virtual-assistant-chatgpt/riva_local/asr/asr.py b/virtual-assistant-chatgpt/riva_local/asr/asr.py new file mode 100644 index 0000000..f093f05 --- /dev/null +++ b/virtual-assistant-chatgpt/riva_local/asr/asr.py @@ -0,0 +1,175 @@ +# ============================================================================== +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# The License information can be found under the "License" section of the +# README.md file. +# ============================================================================== + +import sys +import re +import grpc +import riva.client +from six.moves import queue +from config import riva_config, asr_config + +# Default ASR parameters - Used in case config values not specified in the config.py file +VERBOSE = False +SAMPLING_RATE = 16000 +LANGUAGE_CODE = "en-US" +ENABLE_AUTOMATIC_PUNCTUATION = True +STREAM_INTERIM_RESULTS = True + +class ASRPipe(object): + """Opens a recording stream as a generator yielding the audio chunks.""" + def __init__(self): + self.verbose = asr_config["VERBOSE"] if "VERBOSE" in asr_config else VERBOSE + self.sampling_rate = asr_config["SAMPLING_RATE"] if "SAMPLING_RATE" in asr_config else SAMPLING_RATE + self.language_code = asr_config["LANGUAGE_CODE"] if "LANGUAGE_CODE" in asr_config else LANGUAGE_CODE + self.enable_automatic_punctuation = asr_config["ENABLE_AUTOMATIC_PUNCTUATION"] if "ENABLE_AUTOMATIC_PUNCTUATION" in asr_config else ENABLE_AUTOMATIC_PUNCTUATION + self.stream_interim_results = asr_config["STREAM_INTERIM_RESULTS"] if "STREAM_INTERIM_RESULTS" in asr_config else STREAM_INTERIM_RESULTS + self.chunk = int(self.sampling_rate / 10) # 100ms + self._buff = queue.Queue() + self._transcript = queue.Queue() + self.closed = False + + def start(self): + if self.verbose: + print('[Riva ASR] Creating Stream ASR channel: {}'.format(riva_config["RIVA_SPEECH_API_URL"])) + self.auth = riva.client.Auth(uri=riva_config["RIVA_SPEECH_API_URL"]) + self.riva_asr = riva.client.ASRService(self.auth) + + def close(self): + self.closed = True + self._buff.queue.clear() + self._buff.put(None) # means the end + del(self.auth) + + def empty_asr_buffer(self): + """Clears the audio buffer.""" + if not self._buff.empty(): + self._buff.queue.clear() + + def fill_buffer(self, in_data): + """Continuously collect data from the audio stream, into the buffer.""" + self._buff.put(in_data) + + def get_transcript(self): + """Generator returning chunks of audio transcript""" + while True: # not self.closed: + # Use a blocking get() to ensure there's at least one chunk of + # data, and stop iteration if the chunk is None, indicating the + # end of the audio stream. 
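+            # The items yielded here are the server-sent-event strings built in
+            # listen_print_loop() below, e.g. (illustrative transcripts):
+            #   'event:intermediate-transcript\ndata: hello wor\r\n\n'
+            #   'event:finished-speaking\ndata: Hello world.\n\n'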
+            trans = self._transcript.get()
+            if trans is None:
+                return
+            yield trans
+
+    def build_request_generator(self):
+        """Generates byte-sequences of audio chunks from the audio buffer."""
+        while not self.closed:
+            # Use a blocking get() to ensure there's at least one chunk of
+            # data, and stop iteration if the chunk is None, indicating the
+            # end of the audio stream.
+            chunk = self._buff.get()
+            if chunk is None:
+                return
+            data = [chunk]
+
+            # Now consume whatever other data's still buffered.
+            while True:
+                try:
+                    chunk = self._buff.get(block=False)
+                    if chunk is None:
+                        return
+                    data.append(chunk)
+                except queue.Empty:
+                    break
+
+            yield b''.join(data)
+
+    def listen_print_loop(self, responses):
+        """Iterates through server responses and populates the audio
+        transcription buffer (and prints the responses to stdout).
+
+        The responses passed is a generator that will block until a response
+        is provided by the server.
+
+        Each response may contain multiple results, and each result may contain
+        multiple alternatives; for details, see https://goo.gl/tjCPAU. Here we
+        print only the transcription for the top alternative of the top result.
+
+        In this case, responses are provided for interim results as well. If the
+        response is an interim one, print a carriage return at the end of it, to
+        allow the next result to overwrite it, until the response is a final one.
+        For the final one, print a newline to preserve the finalized transcription.
+        """
+        num_chars_printed = 0
+        for response in responses:
+            if not response.results:
+                continue
+
+            # The `results` list is consecutive. For streaming, we only care about
+            # the first result being considered, since once it's `is_final`, it
+            # moves on to considering the next utterance.
+            result = response.results[0]
+            if not result.alternatives:
+                continue
+
+            # Display the transcription of the top alternative.
+            transcript = result.alternatives[0].transcript
+
+            # Display interim results, but with a carriage return at the end of the
+            # line, so subsequent lines will overwrite them.
+            #
+            # If the previous result was longer than this one, we need to print
+            # some extra spaces to overwrite the previous result.
+            overwrite_chars = ' ' * (num_chars_printed - len(transcript))
+
+            if not result.is_final:
+                sys.stdout.write(transcript + overwrite_chars + '\r')
+                sys.stdout.flush()
+                interm_trans = transcript + overwrite_chars + '\r'
+                interm_str = f'event:{"intermediate-transcript"}\ndata: {interm_trans}\n\n'
+                self._transcript.put(interm_str)
+            else:
+                if self.verbose:
+                    print('[Riva ASR] Transcript:', transcript + overwrite_chars)
+                final_transcript = transcript + overwrite_chars
+                final_str = f'event:{"finished-speaking"}\ndata: {final_transcript}\n\n'
+                self._transcript.put(final_str)
+                num_chars_printed = 0
+        if self.verbose:
+            print('[Riva ASR] Exit')
+
+    def main_asr(self):
+        """Creates a gRPC channel (thread-safe) with the Riva API server for
+        ASR calls, and retrieves recognition/transcription responses."""
+        # See the Riva ASR documentation for the list of supported languages.
+        self.start()
+
+        config = riva.client.RecognitionConfig()
+        config.sample_rate_hertz = self.sampling_rate
+        config.language_code = self.language_code
+        config.max_alternatives = 1
+        config.enable_automatic_punctuation = self.enable_automatic_punctuation
+        config.verbatim_transcripts = True
+        config.audio_channel_count = 1
+        config.encoding = riva.client.AudioEncoding.LINEAR_PCM
+
+        streaming_config = riva.client.StreamingRecognitionConfig(config=config, interim_results=True)
+
+        if self.verbose:
+            print("[Riva ASR] Starting Background ASR process")
+
+        self.request_generator = self.build_request_generator()
+
+        if self.verbose:
+            print("[Riva ASR] StreamingRecognize Start")
+
+        responses = self.riva_asr.streaming_response_generator(audio_chunks=self.request_generator, streaming_config=streaming_config)
+
+        # Now, put the transcription responses to use.
+        self.listen_print_loop(responses)
diff --git a/virtual-assistant-chatgpt/riva_local/chatbot/__init__.py b/virtual-assistant-chatgpt/riva_local/chatbot/__init__.py
new file mode 100644
index 0000000..94aebcf
--- /dev/null
+++ b/virtual-assistant-chatgpt/riva_local/chatbot/__init__.py
@@ -0,0 +1,8 @@
+# ==============================================================================
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# The License information can be found under the "License" section of the
+# README.md file.
+# ==============================================================================
+
+from .chatbot import *
\ No newline at end of file
diff --git a/virtual-assistant-chatgpt/riva_local/chatbot/__pycache__/__init__.cpython-38.pyc b/virtual-assistant-chatgpt/riva_local/chatbot/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000..68ca6e6
Binary files /dev/null and b/virtual-assistant-chatgpt/riva_local/chatbot/__pycache__/__init__.cpython-38.pyc differ
diff --git a/virtual-assistant-chatgpt/riva_local/chatbot/__pycache__/chatbot.cpython-38.pyc b/virtual-assistant-chatgpt/riva_local/chatbot/__pycache__/chatbot.cpython-38.pyc
new file mode 100644
index 0000000..337d433
Binary files /dev/null and b/virtual-assistant-chatgpt/riva_local/chatbot/__pycache__/chatbot.cpython-38.pyc differ
diff --git a/virtual-assistant-chatgpt/riva_local/chatbot/__pycache__/chatbots_multiconversations_management.cpython-38.pyc b/virtual-assistant-chatgpt/riva_local/chatbot/__pycache__/chatbots_multiconversations_management.cpython-38.pyc
new file mode 100644
index 0000000..b5baa9d
Binary files /dev/null and b/virtual-assistant-chatgpt/riva_local/chatbot/__pycache__/chatbots_multiconversations_management.cpython-38.pyc differ
diff --git a/virtual-assistant-chatgpt/riva_local/chatbot/chatbot.py b/virtual-assistant-chatgpt/riva_local/chatbot/chatbot.py
new file mode 100644
index 0000000..96242b4
--- /dev/null
+++ b/virtual-assistant-chatgpt/riva_local/chatbot/chatbot.py
@@ -0,0 +1,106 @@
+# ==============================================================================
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# The License information can be found under the "License" section of the
+# README.md file.
+# ============================================================================== + +import time + +from riva_local.asr.asr import ASRPipe +# from riva_local.tts.tts import TTSPipe +from riva_local.tts.tts_stream import TTSPipe + +from riva_local.chatbot.stateDM.state_machine import StateMachine +from riva_local.chatbot.stateDM.states import initialState + +class ChatBot(object): + """ Class Implementing all the features of the chatbot""" + + def __init__(self, user_conversation_index, verbose=False): + self.thread_asr = None + self.id = user_conversation_index + self.asr = ASRPipe() + self.tts = TTSPipe() + self.enableTTS = False + self.pause_asr_flag = False + self.verbose = verbose + self.stateDM = StateMachine(user_conversation_index, initialState) + + def server_asr(self): + if self.verbose: + print(f'[{self.id }] Starting chatbot ASR task') + self.asr.main_asr() + + def empty_asr_buffer(self): + self.asr.empty_asr_buffer() + if self.verbose: + print(f'[{self.id }] ASR buffer cleared') + + def start_asr(self, sio): + self.thread_asr = sio.start_background_task(self.server_asr) + if self.verbose: + print(f'[{self.id }] ASR background task started') + + def wait(self): + self.thread_asr.join() + if self.verbose: + print(f'[{self.id }] ASR background task terminated') + + def asr_fill_buffer(self, audio_in): + if not self.pause_asr_flag: + self.asr.fill_buffer(audio_in) + + def get_asr_transcript(self): + return self.asr.get_transcript() + + def pause_asr(self): + self.pause_asr_flag = True + + def unpause_asr(self, on): + if on == "REQUEST_COMPLETE" and not self.enableTTS: + self.pause_asr_flag = False + if self.verbose: + print(f'[{self.id }] ASR successfully unpaused for Request Complete') + return True + elif on == "TTS_END": + self.reset_current_tts_duration() + self.pause_asr_flag = False + if self.verbose: + print(f'[{self.id}] ASR successfully unpaused for TTS End') + return True + + def pause_wait_unpause_asr(self): + self.pause_asr_flag = True + time.sleep(1) # Wait till riva has completed tts operation + time.sleep(self.get_current_tts_duration()+2) # Added the 2 extra seconds to account for the flush audio in tts + self.reset_current_tts_duration() + self.pause_asr_flag = False + + def start_tts(self): + self.enableTTS = True + if self.verbose: + print(f'[{self.id }] TTS Enabled') + + def stop_tts(self): + self.enableTTS = False + if self.verbose: + print(f'[{self.id }] TTS Disabled') + + def get_tts_speaking_flag(self): + return self.tts.tts_speaking + + def get_current_tts_duration(self): + return self.tts.get_current_tts_duration() + + def reset_current_tts_duration(self): + self.tts.reset_current_tts_duration() + + def tts_fill_buffer(self, response_text): + if self.enableTTS: + if self.verbose: + print(f'[{self.id }] > client speak: ', response_text) + self.tts.fill_buffer(response_text) + + def get_tts_speech(self): + return self.tts.get_speech() diff --git a/virtual-assistant-chatgpt/riva_local/chatbot/chatbots_multiconversations_management.py b/virtual-assistant-chatgpt/riva_local/chatbot/chatbots_multiconversations_management.py new file mode 100644 index 0000000..dbe072e --- /dev/null +++ b/virtual-assistant-chatgpt/riva_local/chatbot/chatbots_multiconversations_management.py @@ -0,0 +1,35 @@ +# ============================================================================== +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# The License information can be found under the "License" section of the +# README.md file. 
+# ============================================================================== + +from riva_local.chatbot.chatbot import ChatBot + +userbots = {} +user_conversation_cnt = 0 + + +def create_chatbot(user_conversation_index, sio, verbose=False): + if user_conversation_index not in userbots: + userbots[user_conversation_index] = ChatBot(user_conversation_index, + verbose=verbose) + userbots[user_conversation_index].start_asr(sio) + if verbose: + print('[Riva Chatbot] Chatbot created with user conversation index:' + + f'[{user_conversation_index}]') + + +def get_new_user_conversation_index(): + global user_conversation_cnt + user_conversation_cnt += 1 + user_conversation_index = user_conversation_cnt + return str(user_conversation_index) + + +def get_chatbot(user_conversation_index): + if user_conversation_index in userbots: + return userbots[user_conversation_index] + else: + return None diff --git a/virtual-assistant-chatgpt/riva_local/chatbot/stateDM/Util.py b/virtual-assistant-chatgpt/riva_local/chatbot/stateDM/Util.py new file mode 100644 index 0000000..ff00a42 --- /dev/null +++ b/virtual-assistant-chatgpt/riva_local/chatbot/stateDM/Util.py @@ -0,0 +1,279 @@ +# ============================================================================== +# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. +# +# The License information can be found under the "License" section of the +# README.md file. +# ============================================================================== + +import requests +import datetime + +import openai +import random + +try: + import inflect +except ImportError: + print("[Riva DM] Import Error: Import inflect failed!") + raise ImportError + +from config import riva_config, llm_config, LLM_PROMPT_DEFAULT + +p = inflect.engine() + +''' +typical api_response format +{'request': {'type': 'City', 'query': 'London, United Kingdom', 'language': 'en', 'unit': 'm'}, +'location': {'name': 'London', 'country': 'United Kingdom', 'region': 'City of London, Greater London', +'lat': '51.517', 'lon': '-0.106', 'timezone_id': 'Europe/London', 'localtime': '2019-12-10 22:16', +'localtime_epoch': 1576016160, 'utc_offset': '0.0'}, 'current': {'observation_time': '10:16 PM', +'temperature': 10, 'weather_code': 296, 'weather_icons': ['https://assets.weatherstack.com/images/wsymbols01_png_64/wsymbol_0033_cloudy_with_light_rain_night.png'], +'weather_descriptions': ['Light Rain'], 'wind_speed': 24, 'wind_degree': 260, 'wind_dir': 'W', 'pressure': 1006, +'precip': 1.4, 'humidity': 82, 'cloudcover': 0, 'feelslike': 7, 'uv_index': 1, 'visibility': 10, 'is_day': 'no'}} +''' + +# Mapping of intents detected by the Intent & Slot Model to simple intent strings +# that the Large Language Model can understand +# We've added the misspelled intent weather.temprature because that intent is +# misspelled in /models/riva_intent_weather/1/intent_labels.csv +# To clarify further, the problem is in the outputs of the intent slot model, +# not in the sample apps or the Riva Client Python module +llm_weather_intents = { + "weather.weather":"Weather", + "context.weather":"Weather", + "weather.temperature":"Temperature", + "weather.temprature":"Temperature", # Intentional misspelling for debugging + "weather.temperature_yes_no":"Temperature", + "weather.rainfall_yes_no":"Rain", + "weather.rainfall":"Rain", + "weather.snow_yes_no":"Snow", + "weather.snow":"Snow", + "weather.cloudy":"Cloudy", + "weather.sunny":"Sunny", + "weather.humidity":"Humidity", + "weather.humidity_yes_no":"Humidity", +} + 
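+# Illustrative lookups: the class name from the Intent & Slot model resolves to
+# the simple label used in the LLM prompt (the second intent string below is a
+# hypothetical, unmapped model output).
+#
+#   llm_weather_intents.get("weather.rainfall_yes_no")  # -> "Rain"
+#   llm_weather_intents.get("smalltalk.greeting")       # -> None (unhandled)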
+LLM_ERROR_RESPONSE="Sorry, I could not connect to the LLM Service. Please check the configurations again." + +def text2int(textnum, numwords={}): + if not numwords: + units = [ + "zero", "one", "two", "three", "four", "five", "six", "seven", "eight", + "nine", "ten", "eleven", "twelve", "thirteen", "fourteen", "fifteen", + "sixteen", "seventeen", "eighteen", "nineteen", + ] + + tens = ["", "", "twenty", "thirty", "forty", "fifty", "sixty", "seventy", "eighty", "ninety"] + + scales = ["hundred", "thousand", "million", "billion", "trillion"] + + numwords["and"] = (1, 0) + for idx, word in enumerate(units): numwords[word] = (1, idx) + for idx, word in enumerate(tens): numwords[word] = (1, idx * 10) + for idx, word in enumerate(scales): numwords[word] = (10 ** (idx * 3 or 2), 0) + + current = result = 0 + + try: + for word in textnum.split(): + if word not in numwords: + raise Exception("Illegal word: " + word) + + scale, increment = numwords[word] + current = current * scale + increment + if scale > 100: + result += current + current = 0 + + except Exception as e: + print(e) + # If an Illegal word is detected, ignore the whole weathertime + return 0 + + return result + current + + +class WeatherService: + + def __init__(self): + self.access_key = riva_config["WEATHERSTACK_ACCESS_KEY"] + self.days_of_week = {'monday': 0, 'tuesday': 1, 'wednesday': 2, 'thursday': 3, 'friday': 4, 'saturday': 5, 'sunday': 6} + self.weekend = 'weekend' + + def time_to_days(self, context): + if riva_config['VERBOSE']: + print('[Riva Weather] Time info from the query:', context['payload']) + ctxtime = False + if 'weatherforecastdaily' in context['payload']: + ctxtime = context['payload']['weatherforecastdaily'].lower() + if 'weathertime' in context['payload']: + ctxtime = context['payload']['weathertime'].lower() + if ctxtime == "week": + if 'weatherforecastdaily' in context['payload']: + ctxtime = context['payload']['weatherforecastdaily'].lower() + " " + ctxtime + else: + ctxtime = False + if 'day_of_week' in context['payload']: + ctxtime = context['payload']['day_of_week'].lower() + if ctxtime: + context['time'] = ctxtime + if 'now' in ctxtime: + return 0 + elif 'tomorrow' in ctxtime: + return 1 + elif 'next week' in ctxtime: + return 7 + elif 'yesterday' in ctxtime: + return -1 + elif 'last week' in ctxtime: + return -7 + elif ctxtime in self.days_of_week: + diff = self.days_of_week[ctxtime] - datetime.datetime.today().weekday() + if diff<0: + diff+=7 + return diff + elif self.weekend in ctxtime: + context['time'] = 'during the weekend' + return self.days_of_week['sunday'] - datetime.datetime.today().weekday() + elif 'weathertime' in context['payload']: + if not isinstance(context['payload']['weathertime'], int): + q = text2int(context['payload']['weathertime']) + else: + q = context['payload']['weathertime'] + context['time'] = "in {} {}".format(context['payload']['weathertime'], ctxtime) + if 'week' in ctxtime: + return q*7 + elif 'days' in ctxtime: + return q + return 0 + + def query_weather(self, location, response): + params = { + 'access_key': self.access_key, + 'query': location + } + try: + api_result = requests.get('http://api.weatherstack.com/current', params) + api_response = api_result.json() + if riva_config['VERBOSE']: + print("[Riva Weather] Weather API Response: " + str(api_response)) + + if 'success' in api_response and api_response['success'] == False: + response['success'] = False + return + + response['success'] = True + response['country'] = api_response['location']['country'] + 
response['city'] = api_response['location']['name'] + response['condition'] = api_response['current']['weather_descriptions'][0] + response['temperature_c'] = api_response['current']['temperature'] + response['temperature_c_int'] = api_response['current']['temperature'] + response['humidity'] = api_response['current']['humidity'] + response['wind_mph'] = api_response['current']['wind_speed'] + response['precip'] = api_response['current']['precip'] + except: + response['success'] = False + + def query_weather_forecast(self, location, day, response): + params = { + 'access_key': self.access_key, + 'query': location + } + try: + api_result = requests.get('http://api.weatherstack.com/current', params) + api_response = api_result.json() + + if 'success' in api_response and api_response['success'] == False: + response['success'] = False + return + response['success'] = True + response['country'] = api_response['location']['country'] + response['city'] = api_response['location']['name'] + response['condition'] = api_response['current']['weather_descriptions'][0] + response['temperature_c'] = p.number_to_words(api_response['current']['temperature']) + response['temperature_c_int'] = api_response['current']['temperature'] + response['humidity'] = p.number_to_words(api_response['current']['humidity']) + response['wind_mph'] = p.number_to_words(api_response['current']['wind_speed']) + except: + response['success'] = False + + def query_weather_historical(self, location, day, response): + params = { + 'access_key': self.access_key, + 'query': location + } + try: + api_result = requests.get('http://api.weatherstack.com/current', params) + api_response = api_result.json() + + if 'success' in api_response and api_response['success'] == False: + response['success'] = False + return + + response['success'] = True + response['country'] = api_response['location']['country'] + response['city'] = api_response['location']['name'] + response['condition'] = api_response['current']['weather_descriptions'][0] + response['temperature_c'] = p.number_to_words(api_response['current']['temperature']) + response['temperature_c_int'] = api_response['current']['temperature'] + response['humidity'] = p.number_to_words(api_response['current']['humidity']) + response['wind_mph'] = p.number_to_words(api_response['current']['wind_speed']) + + except: + response['success'] = False + +def query_llm(intent, timeinfo, weather_data): + """ + This function prompts the LLM service to paraphrase real-time weather data to a natural sounding human-like response. + + Args: + intent: The intent of the user query determined by the Intent & Slot model. For ex. weather, rain, snow, temperature, humidity etc. + timeinfo: The time of the weather request. + weather_data: The response of the fulfillment service that contains real-time weather information. + + Returns: + The weather response paraphrased by the LLM service. + """ + + # Default error response + llm_response = LLM_ERROR_RESPONSE + + # Step 1: Set the OpenAI API key + openai.api_key = llm_config["API_KEY"] + + # Real-time weather data is string formatted into a query + # which will be added to a few examples of paraphrasing weather data when querying the service. 
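+    # For illustration (hypothetical values), the formatted query below comes out
+    # as one more block in the same shape as the few-shot examples:
+    #
+    #   Intent: Rain
+    #   Condition: Light Rain
+    #   Place: London
+    #   Time: Today
+    #   Temperature: 12 C
+    #   Humidity: 80 percent
+    #   Wind Speed: 9 mph
+    #
+    #   Misty:
+    #
+    # The model then completes the reporter's line after "Misty:".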
+    query = '\n\nIntent: {intent}\nCondition: {condition}\nPlace: {city}\nTime: {time}\nTemperature: {temperature} C\nHumidity: {humidity} percent\nWind Speed: {wind_speed} mph\n\nMisty:'.format(intent=intent, condition=weather_data["condition"], city=weather_data["city"], time=timeinfo, temperature=weather_data["temperature_c"], humidity=weather_data["humidity"], wind_speed=weather_data["wind_mph"])
+
+    if llm_config["VERBOSE"]:
+        print("Query to LLM Service:", LLM_PROMPT_DEFAULT + query)
+
+    try:
+        # Step 2: Call the LLM service to generate a completion.
+        # The query with real-time weather data is added to the few-shot prompt in LLM_PROMPT_DEFAULT (refer /config.py).
+        # The description and ranges of the various parameters can be found in the OpenAI API reference:
+        # https://platform.openai.com/docs/api-reference/chat
+        response = openai.ChatCompletion.create(
+            model=llm_config["API_MODEL_NAME"],
+            messages=[{"role": "system", "content": LLM_PROMPT_DEFAULT}, {"role": "user", "content": query}],
+            temperature=llm_config["TEMPERATURE"],  # the degree of randomness of the model's output
+            top_p=llm_config["TOP_P"],
+            n=1,
+            stop=llm_config["STOP_WORDS"],  # stop words
+            max_tokens=llm_config["TOKENS_TO_GENERATE"],  # tokens to generate
+            presence_penalty=llm_config["PRESENCE_PENALTY"],  # \in [-2.0, 2.0]
+            frequency_penalty=llm_config["REPETITION_PENALTY"],  # \in [-2.0, 2.0], as opposed to [1.0, 2.0] for nemollm's repetition_penalty
+        )
+        llm_response = response.choices[0].message["content"]
+        if llm_response.startswith("Misty: "):
+            llm_response = llm_response[7:]
+
+        if llm_config["VERBOSE"]:
+            print("Response from LLM Service:", llm_response)
+    except Exception as e:
+        print(e)
+
+    return llm_response
diff --git a/virtual-assistant-chatgpt/riva_local/chatbot/stateDM/Weather.py b/virtual-assistant-chatgpt/riva_local/chatbot/stateDM/Weather.py
new file mode 100644
index 0000000..ab42da4
--- /dev/null
+++ b/virtual-assistant-chatgpt/riva_local/chatbot/stateDM/Weather.py
@@ -0,0 +1,54 @@
+# ==============================================================================
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# The License information can be found under the "License" section of the
+# README.md file.
+# ==============================================================================
+
+from riva_local.chatbot.stateDM.state import State
+from riva_local.chatbot.stateDM.Util import WeatherService, query_llm, llm_weather_intents
+
+
+DEFAULT_MESSAGE = "Unfortunately the weather service is not available at this time. Check your connection to weatherstack.com, set a different API key in your configuration, or try again later."
+
+
+class Weather(State):
+    def __init__(self, bot, uid):
+        super(Weather, self).__init__("Weather", bot, uid)
+        self.next_state = None
+
+    # NOTE: weather forecast and weather historical are paid options in weatherstack;
+    # the forecast and historical methods here return the current data only for now.
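+    # Shape of the request frame this state consumes (values illustrative):
+    #   request_data['context'] = {
+    #       'intent': 'weather.rainfall_yes_no',
+    #       'location': 'Paris',
+    #       'payload': {'weathertime': 'tomorrow'},
+    #   }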
+
+    def run(self, request_data):
+        ws = WeatherService()
+
+        # Extract time information
+        if 'weatherforecastdaily' in request_data['context']['payload']:
+            timeinfo = request_data['context']['payload']['weatherforecastdaily']
+        elif 'weathertime' in request_data['context']['payload']:
+            timeinfo = request_data['context']['payload']['weathertime']
+        else:
+            timeinfo = "Today"  # Default
+
+        # Map the Intent & Slot model's intent to a label the LLM can understand
+        if request_data['context']['intent'] in llm_weather_intents:
+            response = {}
+            ws.query_weather(request_data['context']['location'], response)
+
+            # Query the LLM service to paraphrase the weather data into a natural sounding response
+            if response['success']:
+                message = query_llm(intent=llm_weather_intents[request_data['context']['intent']],
+                                    timeinfo=timeinfo,
+                                    weather_data=response)
+            else:
+                message = DEFAULT_MESSAGE
+        else:
+            # TODO: Add support for small talk
+            message = "Sorry, I did not understand the query."
+
+        request_data['context'].update({'weather_status': message})
+
+        # Update the response text with the weather status
+        request_data.update({'response':
+                             self.construct_message(request_data, message)})
\ No newline at end of file
diff --git a/virtual-assistant-chatgpt/riva_local/chatbot/stateDM/__init__.py b/virtual-assistant-chatgpt/riva_local/chatbot/stateDM/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/virtual-assistant-chatgpt/riva_local/chatbot/stateDM/__pycache__/Util.cpython-38.pyc b/virtual-assistant-chatgpt/riva_local/chatbot/stateDM/__pycache__/Util.cpython-38.pyc
new file mode 100644
index 0000000..c4b06d8
Binary files /dev/null and b/virtual-assistant-chatgpt/riva_local/chatbot/stateDM/__pycache__/Util.cpython-38.pyc differ
diff --git a/virtual-assistant-chatgpt/riva_local/chatbot/stateDM/__pycache__/Weather.cpython-38.pyc b/virtual-assistant-chatgpt/riva_local/chatbot/stateDM/__pycache__/Weather.cpython-38.pyc
new file mode 100644
index 0000000..707faa4
Binary files /dev/null and b/virtual-assistant-chatgpt/riva_local/chatbot/stateDM/__pycache__/Weather.cpython-38.pyc differ
diff --git a/virtual-assistant-chatgpt/riva_local/chatbot/stateDM/__pycache__/__init__.cpython-38.pyc b/virtual-assistant-chatgpt/riva_local/chatbot/stateDM/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000..dea2400
Binary files /dev/null and b/virtual-assistant-chatgpt/riva_local/chatbot/stateDM/__pycache__/__init__.cpython-38.pyc differ
diff --git a/virtual-assistant-chatgpt/riva_local/chatbot/stateDM/__pycache__/state.cpython-38.pyc b/virtual-assistant-chatgpt/riva_local/chatbot/stateDM/__pycache__/state.cpython-38.pyc
new file mode 100644
index 0000000..ad35a3c
Binary files /dev/null and b/virtual-assistant-chatgpt/riva_local/chatbot/stateDM/__pycache__/state.cpython-38.pyc differ
diff --git a/virtual-assistant-chatgpt/riva_local/chatbot/stateDM/__pycache__/state_data.cpython-38.pyc b/virtual-assistant-chatgpt/riva_local/chatbot/stateDM/__pycache__/state_data.cpython-38.pyc
new file mode 100644
index 0000000..0fedf6b
Binary files /dev/null and b/virtual-assistant-chatgpt/riva_local/chatbot/stateDM/__pycache__/state_data.cpython-38.pyc differ
diff --git a/virtual-assistant-chatgpt/riva_local/chatbot/stateDM/__pycache__/state_machine.cpython-38.pyc b/virtual-assistant-chatgpt/riva_local/chatbot/stateDM/__pycache__/state_machine.cpython-38.pyc
new file mode 100644
index 0000000..b65a322
Binary files /dev/null and
b/virtual-assistant-chatgpt/riva_local/chatbot/stateDM/__pycache__/state_machine.cpython-38.pyc differ diff --git a/virtual-assistant-chatgpt/riva_local/chatbot/stateDM/__pycache__/states.cpython-38.pyc b/virtual-assistant-chatgpt/riva_local/chatbot/stateDM/__pycache__/states.cpython-38.pyc new file mode 100644 index 0000000..f3f131a Binary files /dev/null and b/virtual-assistant-chatgpt/riva_local/chatbot/stateDM/__pycache__/states.cpython-38.pyc differ diff --git a/virtual-assistant-chatgpt/riva_local/chatbot/stateDM/state.py b/virtual-assistant-chatgpt/riva_local/chatbot/stateDM/state.py new file mode 100644 index 0000000..20abc2b --- /dev/null +++ b/virtual-assistant-chatgpt/riva_local/chatbot/stateDM/state.py @@ -0,0 +1,44 @@ +# ============================================================================== +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# The License information can be found under the "License" section of the +# README.md file. +# ============================================================================== + +from abc import ABC, abstractmethod +import sys + +class State(ABC): + """ State is an abstract class """ + + def __init__(self, name, bot, uid): + self.name = name # Name of the state + self.bot = bot # Name of the chatbot eg. "rivaWeather" + self.uid = uid + self.next_state = None + + @abstractmethod + def run(self, request_data): + assert 0, "Run not implemented!" + + def next(self): + # This should only be run after populating next_state + return self.next_state + + def construct_message(self, request_data, text): + """ Constructs the response frame, + appending to a prev response if that exists """ + message = {'type': 'text', + 'payload': {'text': text}, + 'delay': 0} + + prev_response = request_data.get('response', False) + + # If there was an old response, append the new response to the list + if prev_response: + prev_response.append(message) + # Else create a list containing the response + else: + prev_response = [message] + + return prev_response \ No newline at end of file diff --git a/virtual-assistant-chatgpt/riva_local/chatbot/stateDM/state_data.py b/virtual-assistant-chatgpt/riva_local/chatbot/stateDM/state_data.py new file mode 100644 index 0000000..7de7a9b --- /dev/null +++ b/virtual-assistant-chatgpt/riva_local/chatbot/stateDM/state_data.py @@ -0,0 +1,36 @@ +# ============================================================================== +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# The License information can be found under the "License" section of the +# README.md file. 
+# ============================================================================== + +# This is used for finding the state to transition to based on intent +# We've added the misspelled intent weather.temprature because that intent is +# misspelled in /models/riva_intent_weather/1/intent_labels.csv +# To clarify further, the problem is in the outputs of the intent slot model, +# not in the sample apps or the Riva Client Python module +intent_transitions = { + 'rivaWeather': { + 'weather.qa_answer': 'checkWeatherLocation', + 'weather.weather': 'checkWeatherLocation', + 'context.weather': 'checkWeatherLocation', + 'weather.temperature': 'checkWeatherLocation', + 'weather.temprature': 'checkWeatherLocation', # Intentional misspelling for debugging + 'weather.sunny': 'checkWeatherLocation', + 'weather.cloudy': 'checkWeatherLocation', + 'weather.snow': 'checkWeatherLocation', + 'weather.rainfall': 'checkWeatherLocation', + 'weather.snow_yes_no': 'checkWeatherLocation', + 'weather.rainfall_yes_no': 'checkWeatherLocation', + 'weather.temperature_yes_no': 'checkWeatherLocation', + 'weather.humidity': 'checkWeatherLocation', + 'weather.humidity_yes_no': 'checkWeatherLocation', + 'navigation.startnavigationpoi': 'checkWeatherLocation', + 'navigation.geteta': 'checkWeatherLocation', + 'navigation.showdirection': 'checkWeatherLocation', + 'riva_error': 'error', + 'navigation.showmappoi': 'error', + 'nomatch.none': 'error' + } +} \ No newline at end of file diff --git a/virtual-assistant-chatgpt/riva_local/chatbot/stateDM/state_machine.py b/virtual-assistant-chatgpt/riva_local/chatbot/stateDM/state_machine.py new file mode 100644 index 0000000..420c1e2 --- /dev/null +++ b/virtual-assistant-chatgpt/riva_local/chatbot/stateDM/state_machine.py @@ -0,0 +1,59 @@ +# ============================================================================== +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# The License information can be found under the "License" section of the +# README.md file. +# ============================================================================== + +import copy +from riva_local.chatbot.stateDM.states import userInput, userLocationInput +from config import riva_config + +verbose = riva_config["VERBOSE"] + +############################################################################### +# stateDM (Simple Dialog Manager): A Finite State Machine +############################################################################### +class StateMachine: + def __init__(self, user_conversation_index, init): + self.uid = user_conversation_index + self.bot = "rivaWeather" + if verbose: + print("[stateDM] Initializing the state machine for uid: ", self.uid) + self.currentState = init(self.bot, self.uid) + + def execute_state(self, bot, context, text): + # Fresh request frame + request_data = {'context': context, + 'text': text, + 'uid': self.uid, + 'payload': {}} + + # TODO: Add support for !undo (saving previous context) and !reset + + # Keep executing the state machine until a user input is required + # i.e. 
stop when the state is either userInput or userLocationInput
+        while True:
+            # Run the current state
+            if verbose:
+                print("[stateDM] Executing state:",
+                      self.currentState.name)
+            self.currentState.run(request_data)
+            nextState = self.currentState.next()
+
+            # If the next state exists
+            if nextState is not None:
+                # Create an object from the next state
+                self.currentState = nextState(self.bot, self.uid)
+                # If the next state requires user input, just return
+                # WARNING: Can go into infinite loop if states don't have
+                # next_state configured properly
+                if nextState == userInput or nextState == userLocationInput:
+                    return request_data
+
+            # If no next state exists, wait for user input now
+            else:
+                if verbose:
+                    print("[stateDM] No next state, waiting for user input")
+                self.currentState = userInput(self.bot, self.uid)
+                return request_data
diff --git a/virtual-assistant-chatgpt/riva_local/chatbot/stateDM/states.py b/virtual-assistant-chatgpt/riva_local/chatbot/stateDM/states.py
new file mode 100644
index 0000000..c9c9709
--- /dev/null
+++ b/virtual-assistant-chatgpt/riva_local/chatbot/stateDM/states.py
@@ -0,0 +1,139 @@
+# ==============================================================================
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# The License information can be found under the "License" section of the
+# README.md file.
+# ==============================================================================
+
+from riva_local.chatbot.stateDM.state import State
+from riva_local.nlp.nlp import get_entities
+from riva_local.chatbot.stateDM.state_data import intent_transitions
+from riva_local.chatbot.stateDM.Weather import Weather
+import sys
+from config import riva_config
+
+verbose = riva_config["VERBOSE"]
+
+
+class initialState(State):
+    def __init__(self, bot, uid):
+        super(initialState, self).__init__("initialState", bot, uid)
+
+    def run(self, request_data):
+        text = "Hi, welcome to Misty's weather service. How may I help you?"
+
+        # Update response with welcome text
+        request_data.update({'response':
+                             self.construct_message(request_data, text)})
+
+        self.next_state = userInput
+
+
+class userInput(State):
+    def __init__(self, bot, uid):
+        super(userInput, self).__init__("userInput", bot, uid)
+        self.next_state = None
+
+    def get_state(self, class_str, default):
+        return getattr(sys.modules[__name__], class_str, default)
+
+    def run(self, request_data):
+        # Get response from Riva NLU
+        response = get_entities(request_data['text'], "riva")
+        response_intent = response.get('intent', False)
+
+        # Fetch the transitions dict for the bot
+        intents_project = intent_transitions[self.bot]
+
+        # If a valid intent was detected
+        if response_intent:
+            # If a valid state exists for the response intent AND
+            # the response intent is different from the one already in context
+            if intents_project.get(response_intent, False) and \
+               response_intent != request_data['context'].get('intent', False):
+                self.next_state = self.get_state(intents_project.get(response_intent, False), None)
+
+                # update request_data with response and next_state
+                request_data['context'].update(response)
+                return
+
+        # If intent exists in the context, use that
+        if 'intent' in request_data['context']:
+            # Populate context with response (eg. new entity location value), except the intent
+            request_data['context'].update({x: response[x] for x in response if x != 'intent'})
+            self.next_state = self.get_state(intents_project.get(request_data['context']['intent'], False), None)
+            return
+
+
+class userLocationInput(State):
+    def __init__(self, bot, uid):
+        super(userLocationInput, self).__init__("userLocationInput", bot, uid)
+
+    def run(self, request_data):
+        response = get_entities(request_data['text'], "riva")
+
+        # Updates all keys except intent
+        request_data['context'].update(
+            {x: response[x] for x in response if x != 'intent'})
+
+        # Check if the required entities (location here) are present
+        # If present, proceed to Weather State
+        if 'location' in response:
+            # Move to Weather State
+            self.next_state = Weather
+        else:
+            # Else, proceed to ErrorState
+            self.next_state = error
+
+
+class checkWeatherLocation(State):
+    def __init__(self, bot, uid):
+        super(checkWeatherLocation, self).__init__(
+            "checkWeatherLocation", bot, uid)
+
+    def run(self, request_data):
+        # Check if all entities (location) required for informing weather exist
+        location = request_data['context'].get("location", False)
+
+        if location:
+            # If the location exists, call the Weather class to check the weather for it
+            self.next_state = Weather
+        else:
+            # If not, ask for the location and move to userLocationInput to fetch it
+            text = "For which location?"
+
+            # Update response asking for the user location; the intent stays the same
+            request_data.update({'response':
+                                 self.construct_message(request_data, text)})
+
+            self.next_state = userLocationInput
+
+
+class error(State):
+    def __init__(self, bot, uid):
+        super(error, self).__init__("error", bot, uid)
+
+    def run(self, request_data):
+        text = "Sorry, I didn't get that!"
+
+        # Update response with error text
+        request_data.update({'response':
+                             self.construct_message(request_data, text)})
+
+        self.next_state = userInput
+
+
+# TODO: This state is not in use currently,
+# add this if end of conversation is required
+class end(State):
+    def __init__(self, bot, uid):
+        super(end, self).__init__("end", bot, uid)
+
+    def run(self, request_data):
+        text = "Bye!"
+
+        # Update response with end state text
+        request_data.update({'response':
+                             self.construct_message(request_data, text)})
+
+        self.next_state = userInput
diff --git a/virtual-assistant-chatgpt/riva_local/nlp/__init__.py b/virtual-assistant-chatgpt/riva_local/nlp/__init__.py
new file mode 100644
index 0000000..d90b5a1
--- /dev/null
+++ b/virtual-assistant-chatgpt/riva_local/nlp/__init__.py
@@ -0,0 +1,8 @@
+# ==============================================================================
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# The License information can be found under the "License" section of the
+# README.md file.
+# ============================================================================== + +from .nlp import * \ No newline at end of file diff --git a/virtual-assistant-chatgpt/riva_local/nlp/__pycache__/__init__.cpython-38.pyc b/virtual-assistant-chatgpt/riva_local/nlp/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000..e0449f2 Binary files /dev/null and b/virtual-assistant-chatgpt/riva_local/nlp/__pycache__/__init__.cpython-38.pyc differ diff --git a/virtual-assistant-chatgpt/riva_local/nlp/__pycache__/nlp.cpython-38.pyc b/virtual-assistant-chatgpt/riva_local/nlp/__pycache__/nlp.cpython-38.pyc new file mode 100644 index 0000000..8a5694b Binary files /dev/null and b/virtual-assistant-chatgpt/riva_local/nlp/__pycache__/nlp.cpython-38.pyc differ diff --git a/virtual-assistant-chatgpt/riva_local/nlp/nlp.py b/virtual-assistant-chatgpt/riva_local/nlp/nlp.py new file mode 100644 index 0000000..8d59277 --- /dev/null +++ b/virtual-assistant-chatgpt/riva_local/nlp/nlp.py @@ -0,0 +1,218 @@ +# ============================================================================== +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# The License information can be found under the "License" section of the +# README.md file. +# ============================================================================== + +import riva.client +from riva.client.proto.riva_nlp_pb2 import ( + AnalyzeIntentResponse, + NaturalQueryResponse, + TokenClassResponse +) + +import grpc +from config import riva_config, nlp_config +import requests +import json + +# QA api-endpoint +QA_API_ENDPOINT = nlp_config["QA_API_ENDPOINT"] +enable_qa = riva_config["ENABLE_QA"] +verbose = riva_config["VERBOSE"] + +auth = riva.client.Auth(uri=riva_config["RIVA_SPEECH_API_URL"]) +riva_nlp = riva.client.NLPService(auth) + + +def get_qa_answer(context, question, p_threshold): + # if hasattr(resp, 'intent'): + # entities['intent'] = resp.intent.class_name + + # data to be sent to api + data = { + "question": question, + "context": context + } + # sending post request and saving response as response object + r = requests.post(QA_API_ENDPOINT, json=data) + + # extracting response text + qa_resp = json.loads(r.text) + # print("The response from QA server is :%s"%qa_response) + + if verbose: + print("[Riva NLU] The answer is :%s" % qa_resp['result']) + print("[Riva NLU] The probability is :%s" % qa_resp['p']) + + if qa_resp['result'] == '': + print("[Riva NLU] QA returned empty string.") + + if qa_resp['p'] < p_threshold: + print("[Riva NLU] QA response lower than threshold - ", p_threshold) + # qa_resp['result'] = "I am not too sure about what you meant. " + qa_resp['result'] + # return qa_resp + + return qa_resp + + +if enable_qa == "true": + # test question and passage to be sent to api + riva_test = "I am Riva. I can talk about the weather. My favorite season is spring. I know the weather info " \ + "from Weatherstack api. I have studied the weather all my life." + test_question = "What is your name?" 
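+    # For reference, get_qa_answer() above posts a JSON body shaped like
+    # {"question": "What is your name?", "context": riva_test} and expects a
+    # reply shaped like {"result": "Riva", "p": 0.93} (values illustrative).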
+    p_threshold = 0.4
+    get_qa_answer(riva_test, test_question, p_threshold)
+
+
+def get_intent(resp, entities):
+    if hasattr(resp, 'intent'):
+        entities['intent'] = resp.intent.class_name
+
+
+def get_slots(resp, entities):
+    entities['payload'] = dict()
+    all_entities_class = {}
+    all_entities = []
+    if hasattr(resp, 'slots'):
+        for i in range(len(resp.slots)):
+            slot_class = resp.slots[i].label[0].class_name.replace("\r", "")
+            token = resp.slots[i].token.replace("?", "").replace(",", "").replace(".", "").replace("[SEP]", "").strip()
+            score = resp.slots[i].label[0].score
+            if slot_class and token:
+                if slot_class == 'weatherplace' or slot_class == 'destinationplace':
+                    entity = {"value": token,
+                              "confidence": score,
+                              "entity": "location"}
+                else:
+                    entity = {"value": token,
+                              "confidence": score,
+                              "entity": slot_class}
+                all_entities_class[entity["entity"]] = 1
+                all_entities.append(entity)
+    for cl in all_entities_class:
+        partial_entities = list(filter(lambda x: x["entity"] == cl, all_entities))
+        partial_entities.sort(reverse=True, key=lambda x: x["confidence"])
+        for entity in partial_entities:
+            if cl == "location":
+                entities['location'] = entity["value"]
+            else:
+                entities['payload'][cl] = entity["value"]
+            break
+
+
+def get_riva_output(text):
+    # Submit an AnalyzeIntent request. If no domain is provided with the query, a
+    # domain classifier is run first, and based on the inferred value from the
+    # domain classifier, the query is run through the appropriate intent/slot classifier.
+    # Note: the detected domain is also returned in the response.
+    try:
+        # The detected <domain> is appended to "riva_intent_" to look for a model "riva_intent_<domain>".
+        # So the model "riva_intent_<domain>" needs to be preloaded in the Riva server.
+        # In this case the domain is weather and the model being used is "riva_intent_weather-misc".
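+        # Illustrative outcome for a hypothetical query such as "How's the
+        # weather in Paris?": the response would carry resp.intent.class_name ==
+        # "weather.weather" and a "weatherplace" slot with token "Paris", which
+        # get_intent() and get_slots() fold into the entities dict.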
+        options = riva.client.AnalyzeIntentOptions(lang='en-US', domain='weather')
+
+        resp: AnalyzeIntentResponse = riva_nlp.analyze_intent(text, options)
+
+    except Exception as inst:
+        # An exception occurred
+        print(f"[Riva NLU] Error during NLU request: {inst}")
+        return {'riva_error': 'riva_error'}
+    entities = {}
+    get_intent(resp, entities)
+    get_slots(resp, entities)
+    if 'location' not in entities:
+        if verbose:
+            print(f"[Riva NLU] Did not find any location in the string: {text}\n"
+                  "[Riva NLU] Checking again using NER model")
+        try:
+            model_name = "riva_ner"
+            resp_ner: TokenClassResponse = riva_nlp.classify_tokens(text, model_name)
+        except Exception as inst:
+            # An exception occurred
+            print(f"[Riva NLU] Error during NLU request (riva_ner): {inst}")
+            return {'riva_error': 'riva_error'}
+
+        if verbose:
+            print(f"[Riva NLU] NER response results: \n {resp_ner.results[0].results}\n")
+            print("[Riva NLU] Location Entities:")
+        loc_count = 0
+        for result in resp_ner.results[0].results:
+            if result.label[0].class_name == "LOC":
+                if verbose:
+                    print(f"[Riva NLU] Location found: {result.token}")  # Flow unhandled for multiple location inputs
+                loc_count += 1
+                entities['location'] = result.token
+        if loc_count == 0:
+            if verbose:
+                print("[Riva NLU] No location found in string using NER LOC")
+                print("[Riva NLU] Checking response domain")
+            if resp.domain.class_name == "nomatch.none":
+                # as a final resort, try the QA API
+                if enable_qa == "true":
+                    if verbose:
+                        print("[Riva NLU] Checking using QA API")
+                    riva_misty_profile = requests.get(nlp_config["RIVA_MISTY_PROFILE"]).text  # Live pull from Cloud
+                    qa_resp = get_qa_answer(riva_misty_profile, text, p_threshold)
+                    if not qa_resp['result'] == '':
+                        if verbose:
+                            print("[Riva NLU] received qa result")
+                        entities['intent'] = 'qa_answer'
+                        entities['answer_span'] = qa_resp['result']
+                        entities['query'] = text
+                    else:
+                        entities['intent'] = 'riva_error'
+                else:
+                    entities['intent'] = 'riva_error'
+    if verbose:
+        print("[Riva NLU] This is what entities contain: ", entities)
+    return entities
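+
+# For a query like "How hot is it in Berlin today?", get_riva_output() returns
+# a dict shaped roughly like this (illustrative intent/slot names):
+#
+#   {'intent': 'weather.temperature', 'payload': {...other slots...},
+#    'location': 'Berlin'}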
+
+
+def get_riva_output_qa_only(text):
+    # Unlike get_riva_output(), no domain or intent/slot classification happens
+    # here: the query is sent straight to the QA API against the "Mark"
+    # knowledge base pulled from the cloud.
+    entities = {}
+    try:
+        if enable_qa == "true":
+            if verbose:
+                print("[Riva NLU] Checking using QA API")
+            riva_mark_KB = requests.get(nlp_config["RIVA_MARK_KB"]).text  # Live pull from Cloud
+            qa_resp = get_qa_answer(riva_mark_KB, text, p_threshold)
+            if not qa_resp['result'] == '':
+                if verbose:
+                    print("[Riva NLU] received qa result")
+                entities['intent'] = 'qa_answer'
+                entities['answer_span'] = qa_resp['result']
+                entities['query'] = text
+            else:
+                entities['intent'] = 'riva_error'
+        else:
+            entities['intent'] = 'riva_error'
+    except Exception as inst:
+        # An exception occurred
+        print(f"[Riva NLU] Error during NLU request: {inst}")
+        return {'riva_error': 'riva_error'}
+    if verbose:
+        print("[Riva NLU] This is what entities contain: ", entities)
+    return entities
+
+
+def get_entities(text, nlp_type):
+    if nlp_type is None:
+        nlp_type = "empty"
+
+    ent_out = {}
+    if nlp_type == "empty":
+        ent_out.update({'raw_text': str(text)})
+    elif nlp_type == "riva":
+        riva_out = get_riva_output(text)
+        ent_out.update(riva_out)
+    elif nlp_type == "riva_mark":
+        riva_out = get_riva_output_qa_only(text)
+        ent_out.update(riva_out)
+    return ent_out
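
A quick smoke test for the module above (a hedged sketch: it assumes a running Riva server, the config described in the README, and that weather queries route through the intent/slot path; the printed values are illustrative):

```python
from riva_local.nlp.nlp import get_entities

# "riva" runs intent/slot classification with the NER and QA fallbacks above.
print(get_entities("Is it going to rain in Detroit tomorrow?", "riva"))
# e.g. {'intent': 'weather.rainfall', 'payload': {...}, 'location': 'Detroit'}

# nlp_type None/"empty" just echoes the raw text.
print(get_entities("hello", None))  # {'raw_text': 'hello'}
```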
diff --git a/virtual-assistant-chatgpt/riva_local/tts/__init__.py b/virtual-assistant-chatgpt/riva_local/tts/__init__.py
new file mode 100644
index 0000000..a54072c
--- /dev/null
+++ b/virtual-assistant-chatgpt/riva_local/tts/__init__.py
@@ -0,0 +1,2 @@
+from .tts import *
+from .tts_stream import *
\ No newline at end of file
diff --git a/virtual-assistant-chatgpt/riva_local/tts/__pycache__/__init__.cpython-38.pyc b/virtual-assistant-chatgpt/riva_local/tts/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000..31c2266
Binary files /dev/null and b/virtual-assistant-chatgpt/riva_local/tts/__pycache__/__init__.cpython-38.pyc differ
diff --git a/virtual-assistant-chatgpt/riva_local/tts/__pycache__/tts.cpython-38.pyc b/virtual-assistant-chatgpt/riva_local/tts/__pycache__/tts.cpython-38.pyc
new file mode 100644
index 0000000..6be5360
Binary files /dev/null and b/virtual-assistant-chatgpt/riva_local/tts/__pycache__/tts.cpython-38.pyc differ
diff --git a/virtual-assistant-chatgpt/riva_local/tts/__pycache__/tts_stream.cpython-38.pyc b/virtual-assistant-chatgpt/riva_local/tts/__pycache__/tts_stream.cpython-38.pyc
new file mode 100644
index 0000000..6280b7f
Binary files /dev/null and b/virtual-assistant-chatgpt/riva_local/tts/__pycache__/tts_stream.cpython-38.pyc differ
diff --git a/virtual-assistant-chatgpt/riva_local/tts/tts.py b/virtual-assistant-chatgpt/riva_local/tts/tts.py
new file mode 100644
index 0000000..f4536e3
--- /dev/null
+++ b/virtual-assistant-chatgpt/riva_local/tts/tts.py
@@ -0,0 +1,114 @@
+# ==============================================================================
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# The License information can be found under the "License" section of the
+# README.md file.
+# ==============================================================================
+
+import grpc
+import riva.client
+from six.moves import queue
+from config import riva_config, tts_config
+import numpy as np
+import time
+
+# Default TTS parameters - used in case values are not specified in config.py
+VERBOSE = False
+SAMPLE_RATE = 22050
+LANGUAGE_CODE = "en-US"
+VOICE_NAME = "English-US.Female-1"
+
+class TTSPipe(object):
+    """Opens a gRPC channel to Riva TTS to synthesize speech
+    from text in batch mode."""
+
+    def __init__(self):
+        self.verbose = tts_config["VERBOSE"] if "VERBOSE" in tts_config else VERBOSE
+        self.sample_rate = tts_config["SAMPLE_RATE"] if "SAMPLE_RATE" in tts_config else SAMPLE_RATE
+        self.language_code = tts_config["LANGUAGE_CODE"] if "LANGUAGE_CODE" in tts_config else LANGUAGE_CODE
+        self.voice_name = tts_config["VOICE_NAME"] if "VOICE_NAME" in tts_config else VOICE_NAME
+        self.audio_encoding = riva.client.AudioEncoding.LINEAR_PCM
+        self._buff = queue.Queue()
+        self.closed = False
+        self._flusher = bytes(np.zeros(dtype=np.int16, shape=(self.sample_rate, 1)))  # 1 second of silence
+        self.current_tts_duration = 0
+
+    def start(self):
+        if self.verbose:
+            print('[Riva TTS] Creating Stream TTS channel: {}'.format(riva_config["RIVA_SPEECH_API_URL"]))
+        self.auth = riva.client.Auth(uri=riva_config["RIVA_SPEECH_API_URL"])
+        self.riva_tts = riva.client.SpeechSynthesisService(self.auth)
+
+    def reset_current_tts_duration(self):
+        self.current_tts_duration = 0
+
+    def get_current_tts_duration(self):
+        return self.current_tts_duration
+
+    def fill_buffer(self, in_data):
+        """Collects text responses from the state machine output into a buffer."""
+        if len(in_data):
+            self._buff.put(in_data)
+
+    def close(self):
+        self.closed = True
+        self._buff.queue.clear()
+        self._buff.put(None)  # marks the end
+
+    def get_speech(self):
+        """Returns speech audio from text responses in the buffer"""
+        self.start()
+        wav_header = self.gen_wav_header(self.sample_rate, 16, 1, 0)
+        yield bytes(wav_header)
+        flush_count = 0
+        while not self.closed:
+            if not self._buff.empty():  # Enter if queue/buffer is not empty.
+                try:
+                    text = self._buff.get(block=False, timeout=0)
+                    if self.verbose:
+                        print('[Riva TTS] Pronounced Text: ', text)
+                    # Batch synthesis returns a single response holding all audio.
+                    resp = self.riva_tts.synthesize(
+                        text=text,
+                        language_code=self.language_code,
+                        encoding=riva.client.AudioEncoding.LINEAR_PCM,
+                        sample_rate_hz=self.sample_rate,
+                        voice_name=self.voice_name
+                    )
+                    datalen = len(resp.audio) // 2
+                    data16 = np.ndarray(buffer=resp.audio, dtype=np.int16, shape=(datalen, 1))
+                    speech = bytes(data16.data)
+                    # bytes / (samples-per-second * channels * bytes-per-sample) = seconds
+                    duration = len(data16) * 2 / (self.sample_rate * 1 * 16 / 8)
+                    if self.verbose:
+                        print(f'[Riva TTS] The datalen is: {datalen}')
+                        print(f'[Riva TTS] Duration of audio is: {duration}')
+                    self.current_tts_duration = duration
+                    yield speech
+                    flush_count = 5
+                    continue
+                except Exception as e:
+                    print('[Riva TTS] ERROR:')
+                    print(str(e))
+
+            # To flush out remaining audio from the client buffer
+            if flush_count > 0:
+                yield self._flusher
+                flush_count -= 1
+                continue
+            time.sleep(0.1)  # Set the buffer check rate.
+
+    def gen_wav_header(self, sample_rate, bits_per_sample, channels, datasize):
+        o = bytes("RIFF", 'ascii')                                                  # (4byte) Marks file as RIFF
+        o += (datasize + 36).to_bytes(4, 'little')                                  # (4byte) File size in bytes, excluding this and the RIFF marker
+        o += bytes("WAVE", 'ascii')                                                 # (4byte) File type
+        o += bytes("fmt ", 'ascii')                                                 # (4byte) Format Chunk Marker
+        o += (16).to_bytes(4, 'little')                                             # (4byte) Length of the format data above
+        o += (1).to_bytes(2, 'little')                                              # (2byte) Format type (1 - PCM)
+        o += channels.to_bytes(2, 'little')                                         # (2byte) Channel count
+        o += sample_rate.to_bytes(4, 'little')                                      # (4byte) Sample rate
+        o += (sample_rate * channels * bits_per_sample // 8).to_bytes(4, 'little')  # (4byte) Byte rate
+        o += (channels * bits_per_sample // 8).to_bytes(2, 'little')                # (2byte) Block align
+        o += bits_per_sample.to_bytes(2, 'little')                                  # (2byte) Bits per sample
+        o += bytes("data", 'ascii')                                                 # (4byte) Data Chunk Marker
+        o += datasize.to_bytes(4, 'little')                                         # (4byte) Data size in bytes
+        return o
\ No newline at end of file
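
The 44-byte header emitted by `gen_wav_header` can be sanity-checked with the standard library (a minimal sketch; assumes `config.py` is importable so `TTSPipe` can be constructed):

```python
import io
import wave

from riva_local.tts.tts import TTSPipe

# rate, bits per sample, channels, data size (0: header only, no audio yet)
hdr = TTSPipe().gen_wav_header(22050, 16, 1, 0)
with wave.open(io.BytesIO(hdr)) as w:  # wave parses the RIFF/fmt/data chunks
    assert (w.getframerate(), w.getnchannels(), w.getsampwidth()) == (22050, 1, 2)
```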
diff --git a/virtual-assistant-chatgpt/riva_local/tts/tts_stream.py b/virtual-assistant-chatgpt/riva_local/tts/tts_stream.py
new file mode 100644
index 0000000..e169c1c
--- /dev/null
+++ b/virtual-assistant-chatgpt/riva_local/tts/tts_stream.py
@@ -0,0 +1,123 @@
+# ==============================================================================
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# The License information can be found under the "License" section of the
+# README.md file.
+# ==============================================================================
+
+import grpc
+import riva.client
+
+from six.moves import queue
+from config import riva_config, tts_config
+import numpy as np
+import time
+
+# Default TTS parameters - used in case values are not specified in config.py
+VERBOSE = False
+SAMPLE_RATE = 22050
+LANGUAGE_CODE = "en-US"
+VOICE_NAME = "English-US.Female-1"
+
+class TTSPipe(object):
+    """Opens a gRPC channel to Riva TTS to synthesize speech
+    from text in streaming mode."""
+
+    def __init__(self):
+        self.verbose = tts_config["VERBOSE"] if "VERBOSE" in tts_config else VERBOSE
+        self.sample_rate = tts_config["SAMPLE_RATE"] if "SAMPLE_RATE" in tts_config else SAMPLE_RATE
+        self.language_code = tts_config["LANGUAGE_CODE"] if "LANGUAGE_CODE" in tts_config else LANGUAGE_CODE
+        self.voice_name = tts_config["VOICE_NAME"] if "VOICE_NAME" in tts_config else VOICE_NAME
+        self.audio_encoding = riva.client.AudioEncoding.LINEAR_PCM
+        self._buff = queue.Queue()
+        self.closed = False
+        self._flusher = bytes(np.zeros(dtype=np.int16, shape=(self.sample_rate, 1)))  # 1 second of silence
+        self.current_tts_duration = 0
+
+    def start(self):
+        if self.verbose:
+            print('[Riva TTS] Creating Stream TTS channel: {}'.format(riva_config["RIVA_SPEECH_API_URL"]))
+        self.auth = riva.client.Auth(uri=riva_config["RIVA_SPEECH_API_URL"])
+        self.riva_tts = riva.client.SpeechSynthesisService(self.auth)
+
+    def reset_current_tts_duration(self):
+        self.current_tts_duration = 0
+
+    def get_current_tts_duration(self):
+        return self.current_tts_duration
+
+    def fill_buffer(self, in_data):
+        """Collects text responses from the state machine output into a buffer."""
+        if len(in_data):
+            self._buff.put(in_data)
+
+    def close(self):
+        self.closed = True
+        self._buff.queue.clear()
+        self._buff.put(None)  # marks the end
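+
+    # Data flow: the dialog manager pushes response text into self._buff via
+    # fill_buffer(); get_speech() below drains that buffer, yields synthesized
+    # audio chunks as they arrive from the server, and finally yields short
+    # bursts of silence (self._flusher) to flush the client's audio buffer.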
+
+    def get_speech(self):
+        """Returns speech audio from text responses in the buffer"""
+        self.start()
+        wav_header = self.gen_wav_header(self.sample_rate, 16, 1, 0)
+        yield bytes(wav_header)
+        flush_count = 0
+        while not self.closed:
+            if not self._buff.empty():  # Enter if queue/buffer is not empty.
+                try:
+                    text = self._buff.get(block=False, timeout=0)
+                    if self.verbose:
+                        print('[Riva TTS] Pronounced Text: ', text)
+                        print('[Riva TTS] Starting TTS streaming')
+                    duration = 0
+                    self.current_tts_duration = 0
+                    # Streaming synthesis: audio chunks are returned as they are generated.
+                    responses = self.riva_tts.synthesize_online(
+                        text=text,
+                        language_code=self.language_code,
+                        encoding=riva.client.AudioEncoding.LINEAR_PCM,
+                        sample_rate_hz=self.sample_rate,
+                        voice_name=self.voice_name
+                    )
+                    for resp in responses:
+                        datalen = len(resp.audio) // 2
+                        data16 = np.ndarray(buffer=resp.audio, dtype=np.int16, shape=(datalen, 1))
+                        speech = bytes(data16.data)
+                        duration += len(data16) * 2 / (self.sample_rate * 1 * 16 / 8)
+                        self.current_tts_duration += duration
+                        if self.verbose:
+                            print(f'[Riva TTS] Duration of audio is: {duration}')
+                        yield speech
+                except Exception as e:
+                    print('[Riva TTS] ERROR:')
+                    print(str(e))
+                flush_count = 5
+                continue
+            # To flush out remaining audio from the client buffer
+            if flush_count > 0:
+                yield self._flusher
+                flush_count -= 1
+                continue
+            time.sleep(0.1)  # Set the buffer check rate.
+
+    def gen_wav_header(self, sample_rate, bits_per_sample, channels, datasize):
+        o = bytes("RIFF", 'ascii')                                                  # (4byte) Marks file as RIFF
+        o += (datasize + 36).to_bytes(4, 'little')                                  # (4byte) File size in bytes, excluding this and the RIFF marker
+        o += bytes("WAVE", 'ascii')                                                 # (4byte) File type
+        o += bytes("fmt ", 'ascii')                                                 # (4byte) Format Chunk Marker
+        o += (16).to_bytes(4, 'little')                                             # (4byte) Length of the format data above
+        o += (1).to_bytes(2, 'little')                                              # (2byte) Format type (1 - PCM)
+        o += channels.to_bytes(2, 'little')                                         # (2byte) Channel count
+        o += sample_rate.to_bytes(4, 'little')                                      # (4byte) Sample rate
+        o += (sample_rate * channels * bits_per_sample // 8).to_bytes(4, 'little')  # (4byte) Byte rate
+        o += (channels * bits_per_sample // 8).to_bytes(2, 'little')                # (2byte) Block align
+        o += bits_per_sample.to_bytes(2, 'little')                                  # (2byte) Bits per sample
+        o += bytes("data", 'ascii')                                                 # (4byte) Data Chunk Marker
+        o += datasize.to_bytes(4, 'little')                                         # (4byte) Data size in bytes
+        return o
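
For reference, the streaming call used above can be exercised on its own (a minimal sketch; assumes a Riva server reachable at `localhost:50051` and the default voice shipped with Riva):

```python
import riva.client

auth = riva.client.Auth(uri="localhost:50051")  # assumed endpoint
tts = riva.client.SpeechSynthesisService(auth)

responses = tts.synthesize_online(
    text="Hello from Riva",
    language_code="en-US",
    encoding=riva.client.AudioEncoding.LINEAR_PCM,
    sample_rate_hz=22050,
    voice_name="English-US.Female-1",
)
# Chunks arrive as they are synthesized; concatenating them gives raw 16-bit PCM.
audio = b"".join(resp.audio for resp in responses)
```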
diff --git a/virtual-assistant-nemo-llm/README.md b/virtual-assistant-nemo-llm/README.md
new file mode 100644
index 0000000..978ffb9
--- /dev/null
+++ b/virtual-assistant-nemo-llm/README.md
@@ -0,0 +1,136 @@
+# Weather Virtual Assistant Example
+
+## Overview
+
+The Virtual Assistant (VA) sample demonstrates how to use NVIDIA NeMo LLM along with Riva AI Services to build a simple but complete conversational AI application. It demonstrates receiving input via speech from the user, interpreting the query via an intention recognition and slot filling approach, leveraging the NeMo LLM to generate a natural sounding human-like response, and speaking this back to the user in a natural voice.
+
+## Prerequisites
+
+- This demo uses NVIDIA Riva to support Speech AI capabilities like Automatic Speech Recognition (ASR) and Text-to-Speech (TTS). To run NVIDIA Riva Speech AI services, please ensure you have the prerequisites mentioned [here](https://docs.nvidia.com/deeplearning/riva/user-guide/docs/quick-start-guide.html#data-center).
+- For running this sample application, you'll need:
+  - Access to the NeMo LLM Service through [NVIDIA NGC](https://www.nvidia.com/en-us/gpu-cloud/). You will require your [NGC API key](https://docs.nvidia.com/ngc/ngc-overview/index.html#generating-api-key) to access the service through the API in this sample application.
+  - A Linux x86_64 environment with [pip](https://pypi.org/project/pip/) and Python 3.10 installed.
+  - The [weatherstack API access key](https://weatherstack.com/documentation). The VA uses weatherstack for weather fulfillment; that is, when weather intents are recognized, real-time weather information is fetched from weatherstack. Sign up for the free tier of [weatherstack](https://weatherstack.com/), and get your API access key.
+  - A headset and microphone (for example, a Logitech H390 USB Computer Headset) to communicate with the app.
+
+
+### Setup
+
+1. Download the tar archive with the sample code: `https://llm.ngc.nvidia.com/sample-files/llm/riva-llm-weather-va.tar.xz`
+
+2. Extract the archive:
+```bash
+tar xf riva-llm-weather-va.tar.xz
+cd riva-llm-weather-va
+```
+
+3. Create and enable a Python [virtual environment](https://packaging.python.org/en/latest/guides/installing-using-pip-and-virtual-environments/#creating-a-virtual-environment). For example, with Python 3.10:
+```
+python3.10 -m venv apps-env
+source apps-env/bin/activate
+```
+
+After activating, checking the Python version should reveal the one you created the environment with. For example:
+```
+python3 --version
+```
+*Python 3.10.6*
+
+
+4. Install the libraries necessary for the virtual assistant, including the Riva client library:
+   1. Install the Riva client library.
+   ```
+   pip install nvidia-riva-client
+   ```
+   2. Install the weatherbot web application dependencies. `requirements.txt` captures all Python dependencies needed for the weatherbot web application.
+   ```bash
+   pip install -r requirements.txt  # Tested with Python 3.10
+   ```
+   3. Install the NeMo LLM client library:
+   ```bash
+   pip install nemollm
+   ```
+
+### Running the demo
+1. Start the Riva Speech Server, if not already done. Follow the steps in the [Riva Quick Start Guide](https://docs.nvidia.com/deeplearning/riva/user-guide/docs/quick-start-guide.html). This enables the Speech AI capabilities required for the demo. **Note the IP & port** where the Riva server is running. By default it will run at `<IP>:50051`.
+
+2. Edit the configuration file [config.py](./config.py):
+   1. In `riva_config` set:
+      * The Riva speech server URL. This is the endpoint where the Riva services can be accessed.
+      * The [weatherstack API access key](https://weatherstack.com/documentation). The VA uses weatherstack for weather fulfillment; that is, when weather intents are recognized, real-time weather information is fetched from weatherstack. Sign up for the free tier of [weatherstack](https://weatherstack.com/), and get your API access key.
+   2. In `llm_config` set:
+      * The NGC API access key.
+      * (Optionally) you can also choose the GPT model to use among "gpt5b", "gpt20b" and "gpt530b", and/or modify the parameters for generation.
+
+The code snippets will look like the example below.
+```python3
+riva_config = {
+    "RIVA_SPEECH_API_URL": "<IP>:<port>",  # Replace the IP & port with your hosted Riva endpoint
+    ...
+    "WEATHERSTACK_ACCESS_KEY": "<weatherstack_access_key>",  # Get your access key at - https://weatherstack.com/
+    ...
+}
+...
+llm_config = {
+    ...
+    "API_MODEL_NAME": "gpt20b",
+    "API_KEY": "<NGC_API_KEY>",  # Log in to your NGC account, navigate to your user account at the top right corner, and select "Get API Key". If you already have a key, you don't need to generate it again.
+    ...
+}
+```
+
+3. Run the virtual assistant application:
+```bash
+python3 main.py
+```
+
+4. Open the browser to **https://IP:8009/rivaWeather**, where IP is the address of the machine where the application is running. For instance, go to https://localhost:8009/rivaWeather if the app is running on your local machine.
+
+5. Speak to the virtual assistant through your microphone or type in your text, asking a weather-related query. To hear back text-to-speech audio of the LLM response, click "Unmute System Speech" in the bottom-right corner of the UI.
+
+`NOTE:` To learn about the call to the LLM Service, please refer to the `query_llm` method in `riva_local/chatbot/stateDM/Util.py`.
+
+## Sample Use Cases
+It is possible to ask the bot the following types of questions:
+
+* What is the weather in Berlin?
+
+* What is the weather?
+  * For which location?
+
+* What’s the weather like in San Francisco tomorrow?
+  * What about in California City?
+
+* What is the temperature in Paris on Friday?
+
+* How hot is it in Berlin today?
+
+* Is it currently cold in San Francisco?
+
+* Is it going to rain in Detroit tomorrow?
+
+* How much rain in Seattle?
+
+* Will it be sunny next week in Santa Clara?
+
+* Is it cloudy today?
+
+* Is it going to snow tomorrow in Milwaukee?
+
+* How much snow is there in Toronto currently?
+
+* How humid is it right now?
+
+* What is the humidity in Miami?
+
+* What's the humidity level in San Diego?
+
+## Limitations
+* The sample supports intents for weather, temperature, rain, humidity, sunny, cloudy and snowfall checks. It does not support general conversational queries or other domains.
+* The sample supports only one slot, for city.
+* The sample supports up to four concurrent users. This restriction comes from the web framework (Flask and Flask-SocketIO) being used: the socket connection streams audio to (TTS) and from (ASR) the app, and more than four concurrent socket connections cannot be sustained.
+* The chatbot application is not optimized for low latency in the case of multiple concurrent users.
+* Some erratic issues have been observed with the chatbot sample on the Firefox browser. The most common issue is the TTS output being taken in as input by ASR for certain microphone gain values.
+
+## License
+The [NVIDIA Riva License Agreement](https://developer.nvidia.com/riva/ga/license) is included with the product. Licenses are also available along with the model application zip file. By pulling and using the Riva SDK container, downloading models, or using the sample applications here, you accept the terms and conditions of these licenses.
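+
+As a companion to the `query_llm` pointer above, here is a minimal sketch of what a call to the NeMo LLM Service can look like through the `nemollm` client installed in Setup. The import path, class, and argument names below are assumptions for illustration; the app's actual call lives in `query_llm`:
+
+```python
+from nemollm.api import NemoLLM  # assumed import path for the nemollm package
+
+conn = NemoLLM(api_key="<NGC_API_KEY>")  # the key configured in llm_config["API_KEY"]
+response = conn.generate(
+    model="gpt20b",  # llm_config["API_MODEL_NAME"]
+    prompt="Answer conversationally: it is 20 C and sunny in Berlin today.",
+    tokens_to_generate=64,
+)
+print(response)
+```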
diff --git a/virtual-assistant-nemo-llm/client/webapplication/cert.pem b/virtual-assistant-nemo-llm/client/webapplication/cert.pem new file mode 100644 index 0000000..f8951d3 --- /dev/null +++ b/virtual-assistant-nemo-llm/client/webapplication/cert.pem @@ -0,0 +1,34 @@ +-----BEGIN CERTIFICATE----- +MIIF8TCCA9mgAwIBAgIUCWhCFWAPSu7g+Tn0D4+oYfLn9xUwDQYJKoZIhvcNAQEL +BQAwgYcxCzAJBgNVBAYTAlVTMRMwEQYDVQQIDApDYWxpZm9ybmlhMRUwEwYDVQQH +DAxTYW5GcmFuY2lzY28xEDAOBgNVBAoMB1FWQU5UVk0xCzAJBgNVBAsMAlFWMQsw +CQYDVQQDDAJTRDEgMB4GCSqGSIb3DQEJARYRcXZhbnR2bUBnbWFpbC5jb20wHhcN +MjAwMTE1MTg0NTAyWhcNMjEwMTE0MTg0NTAyWjCBhzELMAkGA1UEBhMCVVMxEzAR +BgNVBAgMCkNhbGlmb3JuaWExFTATBgNVBAcMDFNhbkZyYW5jaXNjbzEQMA4GA1UE +CgwHUVZBTlRWTTELMAkGA1UECwwCUVYxCzAJBgNVBAMMAlNEMSAwHgYJKoZIhvcN +AQkBFhFxdmFudHZtQGdtYWlsLmNvbTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCC +AgoCggIBAM62M4FzhOfntQARneCixYTl2mi8owa2spm600Lopy8Fwg2e5jkY/pyg +znK7foHBvW3eJKxENtwLKIUqNYyOA2F+96FW1fZ9MUF6zcYcccFimjjKr6tYrqDf +GIqAgxOcv4Syv/DaIm3tKDwIYFFtB/kALyG0vBHWV40fvOoqTwCQUSbptM9GBRZe ++beMUS83Zwk4OPPIiSX+P+DEzsBAqUbqU50ilJrzM7VZ+QpVooF/De4moRDSknVs +h/Zbn9a8Wbfb1XpBzm19mnO1vZ6CENUIFdCxvW4qrmfIu346bd+stXcwcJp7bLP0 +vmVnK5DpGIjmzo7n97nFJUO3kFykbJcWpTIwWSc7jsnak+HkwxXAYWZmz0sABDLA +QBK37NuYCqsDxvOibC2X4Y0oRxYut8R0nUYXOtKHBvR5Ug9njIov9lsV8acIm7s9 +r0fignHgFOXmTDMIjK469LwXf0vd1Wy6tZxjNS2ZOG+eyc5sKHCyVP6dEl8xYH6i +/oePn5PmzSr220comUny5NVkUewDYo63A90bQ4X3tdeI/XsPZDXhgQC36+wKFKgy +6WULQhZT4EpxMbiYNF8IjTr6w/5mk1oOT4/nRedcb7/wxlZiL5r/DJlMVuP7zlfq +n3CvLIlpiy6E2b1c7LU+mvVnkq0LQrD+yK30p3JSVwabm6GeY/tnAgMBAAGjUzBR +MB0GA1UdDgQWBBRs9s5GEerZSVFVKGAUIOxNDbE7HDAfBgNVHSMEGDAWgBRs9s5G +EerZSVFVKGAUIOxNDbE7HDAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUA +A4ICAQAv3d5cMLwb76ikYt37P10ggPOjyZwAs90mR/2f25+ADyUaX/wzHtitaA8t +cXfpm4pDICRtJjg5hb6A2Wh+ws/gODbyD915sqYDC6pz0FTxh3n2/BcCQZJWa7Oy +9k7RwixwIYjvGzkWaje9Xi5Jte7tBj4358QFUKKWiWarB+nl2MBAW3wKMRO5stst +pNZ2keaMeG/a5/Ms0PLIeSxN5oUNazs1m3NLB7XFx1ewMKdouamhuCDRyoPixtfm +pas94/IUdGyY6WYaZgGdQtBrM9ro8NtJNFiFQImI0NaQJQCgl45pmapceNzHDHgR +XHnbirr7rRJX6UutrEHNCZcwA5sFQ++AtQycW20vTGJZ1cO+mCYvWEKN0PC4KRhY +Aw8q8ROkX9RvxHl3WGFdpIFXAZtI8O1d7G9ySG1xBLWyYW7s58vdFnsqmjz4Y59b +MWWCIBIAhxLFBRrZ7KEGUj0lJjEuRZ9lkfS4CAAX2gUSlgU/GfUWe8R3bzAqC57t +awACY7PS3m29ILHuKEU4qRarSWdYaJP7ETc6TQDgxXVXA8NTQOMFj73zYILl7MYh +DXvHzk8xKGWqUv2gmKylsqQZPm/mgn/dssOV19TB6LOszX2rJ0ddZ63on5wJqDGl +6BiPGopQSm/mJPNfO3JjvetTEumfinvsac2eTZ99zbTiJ0Pyfw== +-----END CERTIFICATE----- diff --git a/virtual-assistant-nemo-llm/client/webapplication/key.pem b/virtual-assistant-nemo-llm/client/webapplication/key.pem new file mode 100644 index 0000000..0552bb1 --- /dev/null +++ b/virtual-assistant-nemo-llm/client/webapplication/key.pem @@ -0,0 +1,52 @@ +-----BEGIN PRIVATE KEY----- +MIIJQgIBADANBgkqhkiG9w0BAQEFAASCCSwwggkoAgEAAoICAQDOtjOBc4Tn57UA +EZ3gosWE5dpovKMGtrKZutNC6KcvBcINnuY5GP6coM5yu36Bwb1t3iSsRDbcCyiF +KjWMjgNhfvehVtX2fTFBes3GHHHBYpo4yq+rWK6g3xiKgIMTnL+Esr/w2iJt7Sg8 +CGBRbQf5AC8htLwR1leNH7zqKk8AkFEm6bTPRgUWXvm3jFEvN2cJODjzyIkl/j/g +xM7AQKlG6lOdIpSa8zO1WfkKVaKBfw3uJqEQ0pJ1bIf2W5/WvFm329V6Qc5tfZpz +tb2eghDVCBXQsb1uKq5nyLt+Om3frLV3MHCae2yz9L5lZyuQ6RiI5s6O5/e5xSVD +t5BcpGyXFqUyMFknO47J2pPh5MMVwGFmZs9LAAQywEASt+zbmAqrA8bzomwtl+GN +KEcWLrfEdJ1GFzrShwb0eVIPZ4yKL/ZbFfGnCJu7Pa9H4oJx4BTl5kwzCIyuOvS8 +F39L3dVsurWcYzUtmThvnsnObChwslT+nRJfMWB+ov6Hj5+T5s0q9ttHKJlJ8uTV +ZFHsA2KOtwPdG0OF97XXiP17D2Q14YEAt+vsChSoMullC0IWU+BKcTG4mDRfCI06 ++sP+ZpNaDk+P50XnXG+/8MZWYi+a/wyZTFbj+85X6p9wryyJaYsuhNm9XOy1Ppr1 +Z5KtC0Kw/sit9KdyUlcGm5uhnmP7ZwIDAQABAoICAAa6TWDYNqopk22GJUJLaexS 
+YtJn2VJ9ncB9ISUbV12jbVZuJoYTNy432aBIU+y7NoQd58mnirWMs2vqHMYPVTLW +JA8fOWWFW5YK/imFgXpO0EAq8J6+Cyj3OeBAIIQB5QXXn4GiR96WCmoxx5i+2LSU ++fO54ykddcoFD2v7poiZKdr/XkAkwkOhIbWEnpvPzM2zA7+DdltDNCcHoMcHE7tY +IxKJLpcAdV1gqUdZ1CksznJC1ZkrkVK7Do3JG6Gsjar7P65z99j+bol3j81Z5Fxa +oAMj1cuBHh4InXmVQ0A1ac6QSAnvHHGa9JtuSS+1NnQ2NuDV0e086mKS1eL+Av60 +0upMs2wWHn20FclHfF3dPXGXW5U7D1DqY4Zy9p77Qoef8naeVWdruCoMGXEyv85w +H2ZDfeQgjnRQTiWinEvkhJfi+qXQvC689rOtxhUX3il8FWLExV+Hix2K64c5Ne9P +wCOflEQxwqBM3O259N9xaaKy5+fsMDoADzwQZtVEOZQbtPfIFIoZKhijSNpf72eb +MUjroZVWl4ZSniskMf2ZzMtnVCho98VIpX9fh80yGSQ1QGpp/XnzfhpBFNvLH7O1 +NA3+s6XxXVBve/VU5JDW8FEf7UxmppON6q3+Vcaq3YVMP8jpedWY7yAQriGX0gAZ +Bnqo18kH6U0RntFmOGrhAoIBAQD/MtH/tmA/1kuMnKD45Yu+XNtFdzljhelVxOx0 +EJX28ZE1X/pqOvbPUvTFa6EPULnfiajgzE622IQDCYgwkXM4yrn7R+zFA1WBq7O1 +dgbWZCdDA6RCpGK7/oSbOviPpovD3pfUkq9rhvh1VeGGgv0P3zZ5yZX6LUJtjyo6 +tk/1x4FHzPF4wcSgT2ZoQn+NpVBoSc6/8uUTmYSWjxq/wnhK7KZvbeqwhJ9ikFs/ +II5fmoqYwEyQbZQQNKgz1mpW+ZRurf2iEmyEjrY+E4mizqZhnXa8X6Nzn4B6wSeJ +wlK283PTh3NoTwXklTCqLXv1MkIrtE97lXgCEKR7JKfeM8v3AoIBAQDPXGXBxZu2 +KReej0YMrEoVyKyjqdnYb+MyBPNRiL7czlfMeKscldG2e+Z8mkBpVWBIw6KgY8JK +DPg7aziC9f0QUOgWKH33+9YN7xk4PSlkwQZhKyu9IQcYY2xo0LrptVtrzdn7/CyQ +ApdEq2xe6dB+r2Cd3vM1+2dhkuATZK87Nlm3kXTiz7K5qA9nawnzL2+u/rF45izu +kZezXMMpsKW6xxiePPyItVR+4Z4EZhg6kB/QNPkA8tjGptsgVBbiLJM6H6/xtfuf +Mwjh0Yv0jzHiTly+AJL9rx3y5pPqg2rULOSbWRxRLkx75q8KLwfOOMTf+OwtKjyZ +r8MchIkMbBARAoIBADlKtnx6/Ca4vGNH8peOKQ5GmG+C8Z5XPOglepQf+RrkZp4d ++wEIVcp7rDn6DMF8dQ4rQH+fPnisKQ7pf+qvbLeuQ4yXPB+KvRKMcp7lbWmKOIpB +8gmIECZ2YFzdI1pUoIILofh2Ke2w8mydKDFjjN6YVQmIaSQuLwCbqHZf4Zmi/XIa +H4flsHfw+2OisjIhj+ip0UGkjSsWRv7qB65PQWRItqDDg3G3hHTDRcjpTS1Ha6AZ +Y9b00s3ElJJ2q471Hw6t/wf4rOYhh+ZtynODgzTc/gASVIaro1Nrs62os5shErrF +aPJc80y69Z7u8So96z8WjtWG29dS1ypSM2GeLUcCggEBAIYhJlEpGYfDHNwboRwh +deqRW9qhy6AM/9EjEqDy60K41mIUy9o5ruVzT6vZu0BnUVi/8zn8TXjI2ujUekF8 +DK25J+btWk5GQDfTKWUPau8ZTJ8d5bT44DYOWdmS6tSx0ujwxsgQXmLoyiBJIlhi +tdK8bqqvxHJupHihIQBqaE7M4Uu0cv8jimA9LXmf61e6n2t6pCGoAfhvhMkof7U/ +5nPixTHWESP85yMLncMKpzF6eJmdKlRKwZ394FARFJxIaRN3279mD9TylhQ8D2Oq +HIJeXe8pP+uIkr7EF3nid/+26kjyYza/1AlxNlhIA6yJXA/kXCD66SggYPzZXi0C +2YECggEAAl9JXmVcoOEEXx7QoIRfFy8Tbz8dRpID7LInkuFWFc9vpLoCzAIajq78 +wxSQqnvcfUBGQn6ChmdIjSDgi5RJQ9ZxS0Pyze11a8RrXl295jyNJTPB/qMt4ECs +ZwI7yGzk1veDEett6UCZMsFA3SUUTgR9epZSYtg+pGV7cluDGRwj8HcZxPOCgtBB +UcZYd4PS/yOufVjujUVgdaYFLvk6y4WTZB2/HX99MzrHBFWP9NpJ7Ha5w55JcaS/ +XPKFeZLEPWizePuRvANCHoElEUm5DUlHcT4RwSHozGZwNhJhTOs/797p9FxXJKX7 +o2YTBJCBe5ojou7zXimYDgFm+By2AA== +-----END PRIVATE KEY----- diff --git a/virtual-assistant-nemo-llm/client/webapplication/server/__init__.py b/virtual-assistant-nemo-llm/client/webapplication/server/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/virtual-assistant-nemo-llm/client/webapplication/server/server.py b/virtual-assistant-nemo-llm/client/webapplication/server/server.py new file mode 100644 index 0000000..ed7388a --- /dev/null +++ b/virtual-assistant-nemo-llm/client/webapplication/server/server.py @@ -0,0 +1,166 @@ +# ============================================================================== +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# The License information can be found under the "License" section of the +# README.md file. 
+# ==============================================================================
+
+from __future__ import division
+
+import uuid
+import time
+import logging
+from flask import Flask, jsonify, send_from_directory, Response, request
+from flask_cors import CORS
+from flask import stream_with_context
+from flask_socketio import SocketIO, emit
+from os.path import dirname, abspath, join, isdir
+from os import listdir
+from config import client_config
+from engineio.payload import Payload
+
+from riva_local.chatbot.chatbots_multiconversations_management import create_chatbot, get_new_user_conversation_index, get_chatbot
+
+''' Flask Initialization
+'''
+app = Flask(__name__)
+cors = CORS(app)
+log = logging.getLogger('werkzeug')
+log.setLevel(logging.ERROR)
+Payload.max_decode_packets = 500  # https://github.com/miguelgrinberg/python-engineio/issues/142
+sio = SocketIO(app, logger=False)
+verbose = client_config['VERBOSE']
+
+# Methods to serve the client
+@app.route('/rivaWeather/')
+def get_bot1():
+    return send_from_directory("../ui/", "index.html")
+
+@app.route('/rivaWeather/<file>', defaults={'path': ''})
+@app.route('/rivaWeather/<path:path>/<file>')
+def get_bot2(path, file):
+    return send_from_directory("../ui/" + path, file)
+
+
+@app.route('/get_new_user_conversation_index')
+def get_newuser_conversation_index():
+    return get_new_user_conversation_index()
+
+# Audio source for TTS
+@app.route('/audio/<user_conversation_index>/<post_id>')
+def audio(user_conversation_index, post_id):
+    if verbose:
+        print(f'[{user_conversation_index}] audio speak: {post_id}')
+    currentChatbot = get_chatbot(user_conversation_index)
+    return Response(currentChatbot.get_tts_speech())
+
+# Handles ASR audio transcript output
+@app.route('/stream/<user_conversation_index>')
+def stream(user_conversation_index):
+    @stream_with_context
+    def audio_stream():
+        currentChatbot = get_chatbot(user_conversation_index)
+        if currentChatbot:
+            asr_transcript = currentChatbot.get_asr_transcript()
+            for t in asr_transcript:
+                yield t
+        params = {'response': "Audio Works"}
+        return params
+    return Response(audio_stream(), mimetype="text/event-stream")
+
+
+# Used for sending messages to the bot. The expected JSON body looks like
+# (illustrative values):
+#   {"text": "What is the weather in Berlin?", "bot": "rivaweather",
+#    "context": {}, "payload": {}, "user_conversation_index": "1"}
+@app.route("/", methods=['POST'])
+def get_input():
+    try:
+        text = request.json['text']
+        context = request.json['context']
+        bot = request.json['bot'].lower()
+        payload = request.json['payload']
+        user_conversation_index = request.json['user_conversation_index']
+    except KeyError:
+        return jsonify(ok=False, message="Missing parameters.")
+    if user_conversation_index:
+        create_chatbot(user_conversation_index, sio, verbose=client_config['VERBOSE'])
+        currentChatBot = get_chatbot(user_conversation_index)
+        try:
+            response = currentChatBot.stateDM.execute_state(bot, context, text)
+
+            if client_config['DEBUG']:
+                print(f"[{user_conversation_index}] Response from RivaDM: {response}")
+
+            for resp in response['response']:
+                speak = resp['payload']['text']
+                if len(speak):
+                    currentChatBot.tts_fill_buffer(speak)
+            return jsonify(ok=True, messages=response['response'], context=response['context'],
+                           session=user_conversation_index, debug=client_config["DEBUG"])
+        except Exception as e:  # Error in execution
+            print(e)
+            return jsonify(ok=False, message="Error during execution.")
+    else:
+        print("user_conversation_index not found")
+        return jsonify(ok=False, message="user_conversation_index not found")
+
+
+# Writes audio data to ASR buffer
+@sio.on('audio_in', namespace='/')
+def receive_remote_audio(data):
+    currentChatbot = get_chatbot(data["user_conversation_index"])
+    if currentChatbot:
currentChatbot.asr_fill_buffer(data["audio"]) + + +@sio.on('start_tts', namespace='/') +def start_tts(data): + currentChatbot = get_chatbot(data["user_conversation_index"]) + if currentChatbot: + currentChatbot.start_tts() + + +@sio.on('stop_tts', namespace='/') +def stop_tts(data): + currentChatbot = get_chatbot(data["user_conversation_index"]) + if currentChatbot: + currentChatbot.stop_tts() + + +@sio.on('pause_asr', namespace='/') +def pauseASR(data): + currentChatbot = get_chatbot(data["user_conversation_index"]) + if currentChatbot: + if verbose: + print(f"[{data['user_conversation_index']}] Pausing ASR requests.") + currentChatbot.pause_asr() + + +@sio.on('unpause_asr', namespace='/') +def unpauseASR(data): + currentChatbot = get_chatbot(data["user_conversation_index"]) + if currentChatbot: + if verbose: + print(f"[{data['user_conversation_index']}] Attempt at Unpausing ASR requests on {data['on']}.") + unpause_asr_successful_flag = currentChatbot.unpause_asr(data["on"]) + if unpause_asr_successful_flag == True: + emit('onCompleteOf_unpause_asr', {'user_conversation_index': data["user_conversation_index"]}, broadcast=False) + + +@sio.on('pause_wait_unpause_asr', namespace='/') +def pause_wait_unpause_asr(data): + currentChatbot = get_chatbot(data["user_conversation_index"]) + if currentChatbot: + currentChatbot.pause_wait_unpause_asr() + emit('onCompleteOf_unpause_asr', {'user_conversation_index': data["user_conversation_index"]}, broadcast=False) + + +@sio.on("connect", namespace="/") +def connect(): + if verbose: + print('[Riva Chatbot] Client connected') + + +@sio.on("disconnect", namespace="/") +def disconnect(): + if verbose: + print('[Riva Chatbot] Client disconnected') diff --git a/virtual-assistant-nemo-llm/client/webapplication/start_web_application.py b/virtual-assistant-nemo-llm/client/webapplication/start_web_application.py new file mode 100644 index 0000000..d2f978d --- /dev/null +++ b/virtual-assistant-nemo-llm/client/webapplication/start_web_application.py @@ -0,0 +1,17 @@ +# ============================================================================== +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# The License information can be found under the "License" section of the +# README.md file. 
+# ==============================================================================
+
+from client.webapplication.server.server import *
+from config import client_config
+
+def start_web_application():
+    port = client_config["PORT"]
+    host = "0.0.0.0"
+    ssl_context = ('client/webapplication/cert.pem', 'client/webapplication/key.pem')
+    print("Server starting at : https://" + str(host) + ":" + str(port) + "/rivaWeather")
+    print("***Note: Currently streaming works with Chrome and Firefox; Safari does not support navigator.mediaDevices.getUserMedia***")
+    sio.run(app, host=host, port=port, debug=False, use_reloader=False, ssl_context=ssl_context)
diff --git a/virtual-assistant-nemo-llm/client/webapplication/ui/README.md b/virtual-assistant-nemo-llm/client/webapplication/ui/README.md
new file mode 100644
index 0000000..ee235d7
--- /dev/null
+++ b/virtual-assistant-nemo-llm/client/webapplication/ui/README.md
@@ -0,0 +1,18 @@
+# Rivadm client
+
+HTML client for the Rivadm dialogue manager
+
+
+## Usage
+Specify which bot you want to interact with via the URL parameter ``bot=[bot name]``, or by appending the bot name as a path to the address, like:
+
+    http://127.0.0.1:5000/[bot_name]/
+
+You can change the endpoint address of the Rivadm dialogue manager with the URL parameter ``e=[Rivadm endpoint]``.
+
+The default endpoint value is ``http://localhost:5000/``.
+
+Example:
+
+    http://localhost:63342/rivadm-client/index.html?e=http://localhost:5000/&bot=demo_tel
\ No newline at end of file
diff --git a/virtual-assistant-nemo-llm/client/webapplication/ui/img/Rivadm.png b/virtual-assistant-nemo-llm/client/webapplication/ui/img/Rivadm.png
new file mode 100644
index 0000000..3304738
Binary files /dev/null and b/virtual-assistant-nemo-llm/client/webapplication/ui/img/Rivadm.png differ
diff --git a/virtual-assistant-nemo-llm/client/webapplication/ui/img/User.png b/virtual-assistant-nemo-llm/client/webapplication/ui/img/User.png
new file mode 100644
index 0000000..4aa6bdd
Binary files /dev/null and b/virtual-assistant-nemo-llm/client/webapplication/ui/img/User.png differ
diff --git a/virtual-assistant-nemo-llm/client/webapplication/ui/index.html b/virtual-assistant-nemo-llm/client/webapplication/ui/index.html
new file mode 100644
index 0000000..624d17e
--- /dev/null
+++ b/virtual-assistant-nemo-llm/client/webapplication/ui/index.html
@@ -0,0 +1,125 @@
+<!-- [The index.html markup was garbled during extraction; only stray text
+     fragments survive. Recoverable structure: page title "Riva Chatbot"; the
+     NVIDIA/Riva logo panel; the chat area (#communication_area) with the
+     status banners "RIVA CHATBOT STATUS: Talking" and "JOHN SMITH STATUS:
+     Talking"; the TTS <audio> element with its "Unmute System Speech" button;
+     and the text input form (#input_field, #submit, #autosubmitcheck). See
+     script.js and index.css below for the element ids this page defines.] -->
diff --git a/virtual-assistant-nemo-llm/client/webapplication/ui/script.js b/virtual-assistant-nemo-llm/client/webapplication/ui/script.js
new file mode 100644
index 0000000..0f4754f
--- /dev/null
+++ b/virtual-assistant-nemo-llm/client/webapplication/ui/script.js
@@ -0,0 +1,585 @@
+/*
+# ==============================================================================
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# The License information can be found under the "License" section of the
+# top-level README.md file.
+# ==============================================================================
+*/
+
+var endpoint;
+var bot;
+var context = {};
+var payload = {};
+var scrollToBottomTime = 500;
+var infoTextArea; // element for context display
+var debug = false; // if true, display context
+var user_conversation_index = null;
+var socket = null;
+var tts_enabled = false;
+var browser = "";
+var error_servicerecall_limits = {"get_new_user_conversation_index": 2, "init": 2};
+var error_servicerecall_currentcnt = {"get_new_user_conversation_index": 0, "init": 0};
+var error_systemerrormessages_info = {
+    "get_new_user_conversation_index": {
+        "text": "There was an error during a service call. We are unable to proceed further. Please check the console for the Error Log. \n Please resolve this server error and refresh the page to continue",
+        "targetDivText": "Error during Service Call. Unable to proceed."
+    }, "init": {
+        "text": "There was an error during a service call. We are unable to proceed further. Please check the console for the Error Log. \n Please resolve this server error and refresh the page to continue",
+        "targetDivText": "Error during Service Call. Unable to proceed."
+    }, "sendInput": {
+        "text": "There was an error during a service call. We are unable to proceed further. Please check the console for the Error Log. \n Please resolve this server error and refresh the page to continue",
+        "targetDivText": "Error during Service Call. Unable to proceed."
+    }
+};
+
+function disableUserInput() {
+    $("#input_field").prop('disabled', true);
+    $("#submit").prop('disabled', true);
+    $("#autosubmitcheck").prop('disabled', true);
+    $("#unmuteButton").prop('disabled', true);
+}
+
+function enableUserInput() {
+    $("#input_field").prop('disabled', false);
+    $("#submit").prop('disabled', false);
+    $("#autosubmitcheck").prop('disabled', false);
+    $("#unmuteButton").prop('disabled', false);
+}
+
+// ---------------------------------------------------------------------------------------
+// Defines audio src and event handlers for TTS audio
+// ---------------------------------------------------------------------------------------
+function initTTS() {
+    // Set TTS Source as the very first thing
+    let audio = document.getElementById("audio-tts");
+    // Change source to avoid caching
+    audio.src = "/audio/" + user_conversation_index + "/" + new Date().getTime().toString();
+    audio.addEventListener(
+        "waiting", // note: the DOM event name is "waiting", not "onwaiting"
+        function () {
+            console.log("Audio is currently waiting for more data");
+        },
+        false
+    );
+    audio.onplaying = function () {
+        console.log("Audio Playing.");
+    };
+    audio.onwaiting = function () {
+        console.log("Audio is currently waiting for more data");
+    };
+    audio.onended = function () {
+        console.log("Audio onended");
+    };
+    audio.onpause = function () {
+        console.log("Audio onpause");
+    };
+    audio.onstalled = function () {
+        console.log("Audio onstalled");
+        if (browser == "Chrome") {
+            socket.emit("unpause_asr", { "user_conversation_index": user_conversation_index, "on": "TTS_END" });
+        }
+    };
+    audio.onsuspend = function () {
+        console.log("Audio onsuspend");
+    };
+    audio.oncanplay = function () {
+        console.log("Audio oncanplay");
+    };
+    // Chrome will refuse to play without this
+    let unmuteButton = document.getElementById("unmuteButton");
+    unmuteButton.addEventListener("click", function () {
+        if (unmuteButton.innerText == "Unmute System Speech") {
+            tts_enabled = true;
+            socket.emit("start_tts", { "user_conversation_index": user_conversation_index });
+            unmuteButton.innerText = "Mute System Speech";
+            console.log("TTS Play button clicked");
+            audio.play();
+        } else {
+            tts_enabled = false;
+            socket.emit("stop_tts", { "user_conversation_index": user_conversation_index });
+            unmuteButton.innerText = "Unmute System Speech";
+            console.log("TTS Stop button clicked");
+        }
+    });
+    audio.load();
+    audio.play();
+}
+
+// ---------------------------------------------------------------------------------------
+// Initializes input audio (mic) stream processing
+// ---------------------------------------------------------------------------------------
+function initializeRecorderAndConnectSocket() {
+    let namespace = "/";
+    let mediaStream = null;
+
+    // audio recorder functions
+    let initializeRecorder = function (stream) {
+        // https://stackoverflow.com/a/42360902/466693
+        mediaStream = stream;
+        // get sample rate
+        audio_context = new AudioContext();
+        sampleRate = audio_context.sampleRate;
+        let audioInput = audio_context.createMediaStreamSource(stream);
+        let bufferSize = 4096;
+        // record only 1 channel
+        let recorder = audio_context.createScriptProcessor(bufferSize, 1, 1);
+        // specify the processing function
+        recorder.onaudioprocess = function (audioProcessingEvent) {
+            // The input buffer holds the latest microphone samples
+            let inputBuffer = audioProcessingEvent.inputBuffer;
+            // Loop through the output channels (in this case there is only one)
+            for (let channel = 0; channel < 1; channel++) {
+                let inputData = inputBuffer.getChannelData(channel);
+                // Convert float samples to 16-bit PCM, taking every 3rd sample
+                // (a 3x downsample, e.g. 48 kHz capture -> 16 kHz for ASR).
+                function floatTo16Bit(inputArray, startIndex) {
+                    let output = new Int16Array(inputArray.length / 3 - startIndex);
+                    for (let i = 0; i < inputArray.length; i += 3) {
+                        let s = Math.max(-1, Math.min(1, inputArray[i]));
+                        output[i / 3] = s < 0 ? s * 0x8000 : s * 0x7fff;
+                    }
+                    return output;
+                }
+                outputData = floatTo16Bit(inputData, 0);
+                socket.emit("audio_in",
+                    { "user_conversation_index": user_conversation_index, "audio": outputData.buffer });
+            }
+        };
+        // connect stream to our recorder
+        audioInput.connect(recorder);
+        // connect our recorder to the previous destination
+        recorder.connect(audio_context.destination);
+    };
+
+    console.log("socket connection");
+    if (socket == null) {
+        socket = io.connect(
+            location.protocol + "//" + document.domain + ":" + location.port + namespace
+        );
+        socket.on("connect", function () {
+            navigator.mediaDevices
+                .getUserMedia({ audio: true })
+                .then(initializeRecorder)
+                .catch(function (err) {
+                    console.log(">>> ERROR on Socket Connect");
+                });
+        });
+    } else {
+        socket.disconnect();
+        socket.connect();
+    }
+
+    // To stop any open tts buffer from a previous session
+    socket.emit("stop_tts", { "user_conversation_index": user_conversation_index });
+    socket.emit("pause_asr", { "user_conversation_index": user_conversation_index });
+
+    socket.on('onCompleteOf_unpause_asr', function(data) {
+        if (data["user_conversation_index"] == user_conversation_index) {
+            enableUserInput();
+        }
+    });
+}
+
+// -----------------------------------------------------------------------------
+// Retrieves a new "user conversation index" from RivaDM
+// -----------------------------------------------------------------------------
+function get_new_user_conversation_index() {
+    $.ajax({
+        url: endpoint + "get_new_user_conversation_index",
+        type: "get",
+        processData: false,
+        contentType: "application/json; charset=utf-8",
+        dataType: "json",
+        success: function (data, textStatus, jQxhr) {
+            error_servicerecall_currentcnt["get_new_user_conversation_index"] = 0;
+            if (data) {
+                user_conversation_index = data;
+                initializeRecorderAndConnectSocket();
+                init();
+            } else {
+                console.log("No new_user_conversation_index");
+                showSystemErrorMessage("get_new_user_conversation_index", "No new_user_conversation_index");
+                disableUserInput();
+            }
+        },
+        error: function (jqXhr, textStatus, errorThrown) {
+            console.log(errorThrown);
+            if (error_servicerecall_currentcnt["get_new_user_conversation_index"] < error_servicerecall_limits["get_new_user_conversation_index"]) {
+                // If RivaDM doesn't respond, wait and try again. Pass the function
+                // itself to setTimeout; calling it here would fire immediately.
+                error_servicerecall_currentcnt["get_new_user_conversation_index"] = error_servicerecall_currentcnt["get_new_user_conversation_index"] + 1;
+                setTimeout(get_new_user_conversation_index, 3000);
+            } else {
+                error_servicerecall_currentcnt["get_new_user_conversation_index"] = 0;
+                showSystemErrorMessage("get_new_user_conversation_index", errorThrown);
+                disableUserInput();
+            }
+        },
+    });
+}
+
+// -----------------------------------------------------------------------------
+// Call init state
+// -----------------------------------------------------------------------------
+function init() {
+    console.log("init");
+    $.ajax({
+        url: endpoint,
+        type: "post",
+        processData: false,
+        data: JSON.stringify({
+            "text": '',
+            "bot": bot,
+            "context": context,
+            "payload": payload,
+            "user_conversation_index": user_conversation_index
+        }),
+        contentType: "application/json; charset=utf-8",
+        dataType: "json",
+        success: function (data, textStatus, jQxhr) {
+            error_servicerecall_currentcnt["init"] = 0;
+            if (data["ok"]) {
+                if (data["debug"]) {
+                    infoTextArea.style.display = "block";
+                }
+                context = data["context"];
+                payload = {};
+                showSystemMessages(data["messages"]);
+                initTTS();
+                listenASR();
+                socket.emit("unpause_asr", { "user_conversation_index": user_conversation_index, "on": "REQUEST_COMPLETE" });
+                if (tts_enabled == false) {
+                    enableUserInput();
+                } else if (tts_enabled == true && browser == "Firefox") {
+                    socket.emit("pause_wait_unpause_asr", { "user_conversation_index": user_conversation_index });
+                }
+            } else {
+                console.log("Data is not okay!");
+                console.log(data["messages"]);
+                showSystemErrorMessage("init", data["messages"]);
+                disableUserInput();
+            }
+        },
+        error: function (jqXhr, textStatus, errorThrown) {
+            console.log(errorThrown);
+            if (error_servicerecall_currentcnt["init"] < error_servicerecall_limits["init"]) {
+                // If RivaDM doesn't respond, wait and try again
+                error_servicerecall_currentcnt["init"] = error_servicerecall_currentcnt["init"] + 1;
+                setTimeout(init, 3000);
+            } else {
+                error_servicerecall_currentcnt["init"] = 0;
+                showSystemErrorMessage("init", errorThrown);
+                disableUserInput();
+            }
+        },
+    });
+}
+
+// ---------------------------------------------------------------------------------------
+// Send user input to RivaDM by REST
+// ---------------------------------------------------------------------------------------
+function sendInput(text) {
+    socket.emit("pause_asr", { "user_conversation_index": user_conversation_index });
+    disableUserInput();
+    // escape html tags
+    text = text.replace(/</g, "&lt;").replace(/>/g, "&gt;");
+    console.log("sendInput:" + text);
+    $.ajax({
+        url: endpoint,
+        dataType: "json",
+        type: "post",
+        contentType: "application/json; charset=utf-8",
+        data: JSON.stringify({
+            "text": text,
+            "bot": bot,
+            "context": context,
+            "payload": payload,
+            "user_conversation_index": user_conversation_index
+        }),
+        processData: false,
+        success: function (data, textStatus, jQxhr) {
+            if (data["ok"]) {
+                if (data["debug"]) {
+                    infoTextArea.style.display = "block";
+                }
+                context = data["context"];
+                payload = {};
+                showSystemMessages(data["messages"]);
+                socket.emit("unpause_asr", { "user_conversation_index": user_conversation_index, "on": "REQUEST_COMPLETE" });
+                if (tts_enabled == false) {
+                    enableUserInput();
+                } else if (tts_enabled == true && browser == "Firefox") {
+                    socket.emit("pause_wait_unpause_asr", { "user_conversation_index": user_conversation_index });
+                }
+            } else {
+                console.log(data["messages"]);
+                showSystemErrorMessage("sendInput", data["messages"]);
+                disableUserInput();
+            }
+        },
+        error: function (jqXhr, textStatus, errorThrown) {
+            console.log(errorThrown);
+            showSystemErrorMessage("sendInput", errorThrown);
+            disableUserInput();
+        },
+    });
+}
+
+function getTimeString() {
+    var d = new Date();
+    var ampm = "";
+    var h = d.getHours();
+    var m = d.getMinutes();
+    if (h == 0) {
+        h = "12"; ampm = "am";
+    } else if (h < 12) {
+        ampm = "am";
+    } else if (h == 12) {
+        ampm = "pm";
+    } else {
+        h = h - 12; ampm = "pm";
+    }
+    if (m >= 0 && m <= 9) {
+        m = "0" + m;
+    }
+    return h + ":" + m + " " + ampm;
+}
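+// e.g. getTimeString() at 14:07 returns "2:07 pm"; at 00:30 it returns "12:30 am".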
+
+// ---------------------------------------------------------------------------------------
+// Shows responses of RivaDM
+// ---------------------------------------------------------------------------------------
+function showSystemMessages(messages) {
+    if (!messages) return;
+    infoTextArea.innerHTML = JSON.stringify(context, null, 4);
+    for (let i = 0; i < messages.length; i++) {
+        if (messages[i]['type'] == "text") {
+            showSystemMessageText(messages[i]['payload']['text']);
+        }
+    }
+    document.getElementById("target_div").innerHTML =
+        "System replied. Waiting for user input.";
+}
+
+// ---------------------------------------------------------------------------------------
+// Show text message
+// ---------------------------------------------------------------------------------------
+function showSystemMessageText(text) {
+    console.log("showSystemMessages: " + text);
+    // [The HTML template strings in this function were garbled during extraction.
+    //  They build the bot-side chat bubble (cf. .balon2 in index.css) around
+    //  `text`, then wrap it in a row stamped with `currentTime` via the data-is
+    //  attribute.]
+    let well = $(/* bot chat-bubble markup (lost in extraction) */);
+    var currentTime = getTimeString();
+    let welll = $(/* row wrapper markup built from `well` and `currentTime` (lost in extraction) */);
+    setTimeout(function () {
+        $("#communication_area").append(welll.fadeIn("medium"));
+        // scroll to bottom of page
+        setTimeout(function () {
+            var elem = document.getElementById('communication_area');
+            elem.scrollTop = elem.scrollHeight;
+        }, 10);
+    }, 1000);
+}
+
+// ---------------------------------------------------------------------------------------
+// Show system error messages
+// ---------------------------------------------------------------------------------------
+function showSystemErrorMessage(errorsource, errorThrown) {
+    let infoTextAreaText = errorThrown;
+    let text = error_systemerrormessages_info[errorsource]["text"];
+    let targetDivText = error_systemerrormessages_info[errorsource]["targetDivText"];
+
+    infoTextArea.innerHTML = infoTextAreaText;
+    console.log("showSystemMessages: " + text);
+    // [Chat-bubble template strings garbled during extraction, as above.]
+    let well = $(/* error chat-bubble markup (lost in extraction) */);
+    var currentTime = getTimeString();
+    let welll = $(/* row wrapper markup built from `well` and `currentTime` (lost in extraction) */);
+    setTimeout(function () {
+        $("#communication_area").append(welll.fadeIn("medium"));
+        // scroll to bottom of page
+        setTimeout(function () {
+            var elem = document.getElementById('communication_area');
+            elem.scrollTop = elem.scrollHeight;
+        }, 10);
+    }, 1000);
+
+    document.getElementById("target_div").innerHTML = targetDivText;
+}
+
+
+// ---------------------------------------------------------------------------------------
+// Shows message of user
+// ---------------------------------------------------------------------------------------
+function showUserMessage(text) {
+    // escape html tags
+    text = text.replace(/</g, "&lt;").replace(/>/g, "&gt;");
+    // show it on page
+    // [User-side chat-bubble template strings (cf. .balon1 in index.css) were
+    //  garbled during extraction, as above.]
+    let well = $(/* user chat-bubble markup (lost in extraction) */);
+    var currentTime = getTimeString();
+    let welll = $(/* row wrapper markup built from `well` and `currentTime` (lost in extraction) */);
+    setTimeout(function () {
+        $("#communication_area").append(welll);
+        // scroll to bottom of page
+        setTimeout(function () {
+            var elem = document.getElementById('communication_area');
+            elem.scrollTop = elem.scrollHeight;
+        }, 10);
+    }, 100);
+
+    document.getElementById("target_div").innerHTML =
+        "User responded. Waiting for system output.";
+}
+
+function getBrowser() {
+    if (navigator.userAgent.indexOf("Chrome") != -1) {
+        browser = 'Chrome';
+    }
+    else if (navigator.userAgent.indexOf("Safari") != -1) {
+        browser = 'Safari';
+    }
+    else if (navigator.userAgent.indexOf("Firefox") != -1) {
+        browser = 'Firefox';
+    }
+}
+
+// ---------------------------------------------------------------------------------------
+// Gets parameter by name
+// ---------------------------------------------------------------------------------------
+function getParameterByName(name, url) {
+    let arr = url.split("#");
+    let match = RegExp("[?&]" + name + "=([^&]*)").exec(arr[0]);
+    return match && decodeURIComponent(match[1].replace(/\+/g, " "));
+}
+
+// ---------------------------------------------------------------------------------------
+// Get endpoint of RivaDM from URL parameters
+// ---------------------------------------------------------------------------------------
+function getEndpoint() {
+    // Get endpoint from URL
+    let endpoint = getParameterByName("e", window.location.href);
+    // Use default, if no endpoint is present
+    if (endpoint == null) {
+        endpoint = window.location.protocol + "//" + window.location.host + "/";
+    }
+    return endpoint;
+}
+
+// ---------------------------------------------------------------------------------------
+// Get bot from URL parameters
+// ---------------------------------------------------------------------------------------
+function getBot() {
+    // Get bot name from URL
+    let bot = getParameterByName("bot", window.location.href);
+    if (bot == null || bot == "") {
+        bot = window.location.pathname;
+        bot = bot.replace(/\//g, "");
+    }
+    // Use default, if no bot is present
+    if (bot == null) {
+        bot = "";
+    }
+    return bot;
+}
+
+// ---------------------------------------------------------------------------------------
+// Hack to have the same size of input field and submit button
+// ---------------------------------------------------------------------------------------
+function inputFieldSizeHack() {
+    const height = $("#submit_span").outerHeight();
+    $("#submit").outerHeight(height);
+    $("#input_field").outerHeight(height);
+}
+
+// ---------------------------------------------------------------------------------------
+// Function to listen to events from ASR Output stream
+// ---------------------------------------------------------------------------------------
+function listenASR() {
+    let eventSource = new EventSource("/stream/" + user_conversation_index);
+
+    eventSource.addEventListener(
+        "intermediate-transcript",
+        function (e) {
+            document.getElementById("input_field").value = e.data;
+        },
+        false
+    );
+
+    eventSource.addEventListener(
+        "finished-speaking",
+        function (e) {
+            document.getElementById("input_field").value = e.data;
+            if (document.getElementById("autosubmitcheck").checked == true) {
+                document.getElementById("submit").click();
+            }
+        },
+        false
+    );
+}
+
+
+// ---------------------------------------------------------------------------------------
+// Function called right after the page is loaded
+// 
--------------------------------------------------------------------------------------- +$(document).ready(function () { + getBrowser(); + // input field size hack + inputFieldSizeHack(); + $("#input_field").show(); + $("#submit").show(); + infoTextArea = document.getElementById("info-text"); + disableUserInput(); + // Get endpoint from URL address + endpoint = getEndpoint(); // eg. "https://10.110.20.130:8009/" + bot = getBot(); // "rivaWeather" + get_new_user_conversation_index(); +}); + + +// --------------------------------------------------------------------------------------- +// Click on submit button +// --------------------------------------------------------------------------------------- +$(document).on("submit", "#form", function (e) { + // Prevent reload of page after submitting of form + e.preventDefault(); + let text = $("#input_field").val(); + console.log("text: " + text); + if (text != "") { + // Erase input field + $("#input_field").val(""); + // Show user's input immediately + showUserMessage(text); + // Send user's input to RivaDM + sendInput(text); + } +}); diff --git a/virtual-assistant-nemo-llm/client/webapplication/ui/static/stylesheets/index.css b/virtual-assistant-nemo-llm/client/webapplication/ui/static/stylesheets/index.css new file mode 100644 index 0000000..904c4e7 --- /dev/null +++ b/virtual-assistant-nemo-llm/client/webapplication/ui/static/stylesheets/index.css @@ -0,0 +1,329 @@ +/* +# ============================================================================== +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# The License information can be found under the "License" section of the +# top-level README.md file. +# ============================================================================== +*/ + +html, +body { + margin: 0; + width: 100%; + height: 100%; + background-color: black; +} + +#outer_div { + width: 100%; + height: 100%; +} + +#outer_row { + width: 100%; + height: 100%; +} + +#logo_div { + background-color: #000000; +} + +#chat_div { + width: 100%; + background-color: #000000; +} + +#empty_div { + background-color: black; +} + +#mid_box_top { + height: 5%; + width: 100%; + background-color: #042c25; +} + +#mid_box_mid { + width: 100%; + height: 85%; + background-color: #042c25; +} + +#mid_box_bottom { + width: 100%; + height: 10%; + background-color: black; +} + +#left_box_top { + height: 5%; + width: 100%; + background-color: black; +} + +#left_box_mid { + height: 20%; + width: 100%; + background-color: black; +} + +#left_box_bottom { + height: 75%; + width: 100%; + background-color: black; +} + +#logo_image_div { + height: 50%; + width: 100%; + background-color: black; +} + +#title_text_div { + height: 50%; + width: 100%; + background-color: black; +} + +#nv_logo_div { + width: 32%; + height: 100%; +} + +#nv_name_div { + width: 68%; + height: 100%; +} + +.heavy { + font: bold sans-serif; + color: white; + font-size: 110%; + font-weight: 600; + font-stretch: expanded; +} + +#riva_name_div { + margin-left: -1em; + width: 120%; + height: 100%; +} + +#chart_parent { + width: 100%; + height: 100%; +} + +#riva_status { + width: auto; + height: 100%; + background-color: #042c25; +} + +#chat_box { + width: auto; + height: 100%; + background-color: #042c25; +} + +#profile_div { + width: auto; + height: 100%; + background-color: #042c25; +} + +#riva_box1 { + height: 5%; + width: 100%; + background-color: #042c25; +} + +#riva_image { + height: 25%; + width: 100%; + background-color: #042c25; +} + +#riva_live_status { + height: 10%; + 
width: 100%; + background-color: #042c25; + text-align: center; + font-size: 1.2vw; +} + +#riva_buttons { + padding: 5%; + height: 40%; + width: 100%; + background-color: #042c25; + padding-top: 4%; +} + +#audio_area { + padding: 5%; + height: 25%; + width: 100%; + background-color: #042c25; + padding-top: 4%; +} + +#riva-box2 { + height: 5%; + width: 100%; + background-color: #042c25; +} + +.status-buttons { + border: 1px solid #78bc04; + text-align: center; + color: white; + font-size: 0.7vw; + padding-top: 2%; + padding-bottom: 2%; +} +.riva_status_text1 { + color: #78bc04; +} + +.riva_status_text2 { + color: #ffffff; +} + +#riva_image_div { + margin-left: 5%; + + height: 90%; + width: 90%; +} + +a { + text-decoration: none !important; +} + +label { + color: rgba(120, 144, 156, 1) !important; +} + +.btn:focus, +.btn:active:focus, +.btn.active:focus { + outline: none !important; + box-shadow: 0 0px 0px rgba(120, 144, 156, 1) inset, + 0 0 0px rgba(120, 144, 156, 0.8); +} + +textarea:focus, +input[type="text"]:focus, +input[type="password"]:focus, +input[type="datetime"]:focus, +input[type="datetime-local"]:focus, +input[type="date"]:focus, +input[type="month"]:focus, +input[type="time"]:focus, +input[type="week"]:focus, +input[type="number"]:focus, +input[type="email"]:focus, +input[type="url"]:focus, +input[type="search"]:focus, +input[type="tel"]:focus, +input[type="color"]:focus, +.uneditable-input:focus { + border-color: rgba(120, 144, 156, 1); + color: rgba(120, 144, 156, 1); + opacity: 0.9; + box-shadow: 0 0px 0px rgba(120, 144, 156, 1) inset, + 0 0 10px rgba(120, 144, 156, 0.3); + outline: 0 none; +} + +.card::-webkit-scrollbar { + width: 0px; +} + +::-webkit-scrollbar-thumb { + border-radius: 9px; + background: rgba(96, 125, 139, 0.99); +} + +.balon1, +.balon2 { + margin-top: 5px !important; + margin-bottom: 5px !important; +} + +.balon1 a { + background: #ffffff; + color: #000000 !important; + border-radius: 20px 3px 20px 20px; + display: block; + max-width: 75%; + padding: 7px 13px 7px 13px; +} + +.balon1:before { + content: attr(data-is); + position: absolute; + right: 15px; + bottom: -0.8em; + display: block; + font-size: 0.75rem; + color: rgba(84, 110, 122, 1); +} + +.balon2 a { + background: #78bc04; + color: #ffffff !important; + border-radius: 3px 20px 20px 20px; + display: block; + max-width: 75%; + padding: 7px 13px 7px 13px; +} + +.balon2:before { + content: attr(data-is); + position: absolute; + left: 13px; + bottom: -0.8em; + display: block; + font-size: 0.75rem; + color: rgba(84, 110, 122, 1); +} + +.bg-sohbet:before { + content: ""; + top: 0; + left: 0; + bottom: 0; + right: 0; + height: 100%; + background-color: #042c25; + position: absolute; +} + +#target_div { + color: white; +} + +#unmuteButton { + background-color: #78bc04; + color: white; +} + +#unmuteButton:disabled { + background-color: #bc0404; + color: white; +} + +#submit { + background-color: #78bc04; +} + +#input_field { + background-color: white; +} + +#form { + width: 100%; +} diff --git a/virtual-assistant-nemo-llm/client/webapplication/ui/static/svg_files/another_sample.svg b/virtual-assistant-nemo-llm/client/webapplication/ui/static/svg_files/another_sample.svg new file mode 100644 index 0000000..eccf37e --- /dev/null +++ b/virtual-assistant-nemo-llm/client/webapplication/ui/static/svg_files/another_sample.svg @@ -0,0 +1,18 @@ + + + + + + + \ No newline at end of file diff --git a/virtual-assistant-nemo-llm/client/webapplication/ui/static/svg_files/circle.svg 
b/virtual-assistant-nemo-llm/client/webapplication/ui/static/svg_files/circle.svg
new file mode 100644
index 0000000..d989ae6
--- /dev/null
+++ b/virtual-assistant-nemo-llm/client/webapplication/ui/static/svg_files/circle.svg
@@ -0,0 +1,8 @@
+[SVG markup stripped during extraction; 8-line vector image]
\ No newline at end of file
diff --git a/virtual-assistant-nemo-llm/client/webapplication/ui/static/svg_files/logo.svg b/virtual-assistant-nemo-llm/client/webapplication/ui/static/svg_files/logo.svg
new file mode 100644
index 0000000..850a71f
--- /dev/null
+++ b/virtual-assistant-nemo-llm/client/webapplication/ui/static/svg_files/logo.svg
@@ -0,0 +1,50 @@
+[SVG markup stripped during extraction; 50-line vector image]
\ No newline at end of file
diff --git a/virtual-assistant-nemo-llm/client/webapplication/ui/static/svg_files/logo_sample.svg b/virtual-assistant-nemo-llm/client/webapplication/ui/static/svg_files/logo_sample.svg
new file mode 100644
index 0000000..db55989
--- /dev/null
+++ b/virtual-assistant-nemo-llm/client/webapplication/ui/static/svg_files/logo_sample.svg
@@ -0,0 +1,20 @@
+[SVG markup stripped during extraction; 20-line vector image]
\ No newline at end of file
diff --git a/virtual-assistant-nemo-llm/client/webapplication/ui/static/svg_files/nv_logo.svg b/virtual-assistant-nemo-llm/client/webapplication/ui/static/svg_files/nv_logo.svg
new file mode 100644
index 0000000..72d167f
--- /dev/null
+++ b/virtual-assistant-nemo-llm/client/webapplication/ui/static/svg_files/nv_logo.svg
@@ -0,0 +1,10 @@
+[SVG markup stripped during extraction; 10-line vector image]
\ No newline at end of file
diff --git a/virtual-assistant-nemo-llm/client/webapplication/ui/static/svg_files/nv_logo_1.svg b/virtual-assistant-nemo-llm/client/webapplication/ui/static/svg_files/nv_logo_1.svg
new file mode 100644
index 0000000..2d01ba3
--- /dev/null
+++ b/virtual-assistant-nemo-llm/client/webapplication/ui/static/svg_files/nv_logo_1.svg
@@ -0,0 +1,10 @@
+[SVG markup stripped during extraction; 10-line vector image]
diff --git a/virtual-assistant-nemo-llm/client/webapplication/ui/static/svg_files/nvidia_logo.svg b/virtual-assistant-nemo-llm/client/webapplication/ui/static/svg_files/nvidia_logo.svg
new file mode 100644
index 0000000..2f02d02
--- /dev/null
+++ b/virtual-assistant-nemo-llm/client/webapplication/ui/static/svg_files/nvidia_logo.svg
@@ -0,0 +1,64 @@
+[SVG markup stripped during extraction; 64-line vector image]
diff --git a/virtual-assistant-nemo-llm/client/webapplication/ui/static/svg_files/nvidia_name.svg b/virtual-assistant-nemo-llm/client/webapplication/ui/static/svg_files/nvidia_name.svg
new file mode 100644
index 0000000..7b69edd
--- /dev/null
+++ b/virtual-assistant-nemo-llm/client/webapplication/ui/static/svg_files/nvidia_name.svg
@@ -0,0 +1,17 @@
+[SVG markup stripped during extraction; 17-line vector image]
\ No newline at end of file
diff --git a/virtual-assistant-nemo-llm/client/webapplication/ui/static/svg_files/riva_name.png b/virtual-assistant-nemo-llm/client/webapplication/ui/static/svg_files/riva_name.png
new file mode 100644
index 0000000..7f209d1
Binary files /dev/null and b/virtual-assistant-nemo-llm/client/webapplication/ui/static/svg_files/riva_name.png differ
diff --git a/virtual-assistant-nemo-llm/client/webapplication/ui/static/svg_files/sample_.svg b/virtual-assistant-nemo-llm/client/webapplication/ui/static/svg_files/sample_.svg
new file mode 100644
index 0000000..bcf5ec8
--- /dev/null
+++ b/virtual-assistant-nemo-llm/client/webapplication/ui/static/svg_files/sample_.svg
@@ -0,0 +1,25 @@
+[SVG markup stripped during extraction; 25-line vector image]
\ No newline at end of file
diff --git a/virtual-assistant-nemo-llm/client/webapplication/ui/static/svg_files/speech.svg b/virtual-assistant-nemo-llm/client/webapplication/ui/static/svg_files/speech.svg
new file mode 100644
index 0000000..eeef63a
--- /dev/null
+++ b/virtual-assistant-nemo-llm/client/webapplication/ui/static/svg_files/speech.svg
@@ -0,0 +1 @@
+[SVG markup stripped during extraction; single-line vector image]
\ No newline at end of file
diff --git a/virtual-assistant-nemo-llm/client/webapplication/ui/static/svg_files/speech_logo.svg b/virtual-assistant-nemo-llm/client/webapplication/ui/static/svg_files/speech_logo.svg
new file mode 100644
index 0000000..b515646
--- /dev/null
+++ b/virtual-assistant-nemo-llm/client/webapplication/ui/static/svg_files/speech_logo.svg
@@ -0,0 +1,95 @@
+[SVG markup stripped during extraction; 95-line vector image]
\ No newline at end of file
diff --git a/virtual-assistant-nemo-llm/client/webapplication/ui/style.css b/virtual-assistant-nemo-llm/client/webapplication/ui/style.css
new file mode 100644
index 0000000..4edbd02
--- /dev/null
+++ b/virtual-assistant-nemo-llm/client/webapplication/ui/style.css
@@ -0,0 +1,247 @@
+/*
+# ==============================================================================
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# The License information can be found under the "License" section of the
+# top-level README.md file.
+# ==============================================================================
+*/
+
+#info-text {
+  position: fixed;
+  width: 350px;
+  height: 250px;
+  top: 0;
+  right: 10px;
+  opacity: 0.8;
+  display: none;
+}
+
+.profile_picture_left {
+  width: 50px;
+  height: 50px;
+  margin-right: 10px;
+}
+
+.profile_picture_right {
+  width: 50px;
+  height: 50px;
+  margin-left: 10px;
+}
+
+.empty_space {
+  width: 70px;
+  height: 0px;
+}
+
+.btn:focus, .btn:active {
+  outline: none !important;
+}
+
+.button-main {
+  background-color: #009688;
+  color: #FFF;
+}
+
+.button-main:hover {
+  background-color: #00877a;
+  color: #FFF;
+}
+
+.button-main:focus, .button-main:active {
+  background-color: #00796d;
+  color: #FFF;
+}
+
+.button-slave {
+  background-color: #FFFFFF;
+}
+
+.button-slave:hover {
+  background-color: #e5e5e5;
+}
+
+.main-name {
+  color: #48ba2f;
+  text-decoration: none;
+}
+
+.main-name:hover {
+  color: #00877a;
+  text-decoration: none;
+}
+
+.main-name:focus, .button-main:active {
+  color: #00796d;
+  text-decoration: none;
+}
+
+.button {
+  margin: 0px 7px 8px 0px;
+  box-shadow: 0 1px 3px rgba(0, 0, 0, 0.06), 0 1px 2px rgba(0, 0, 0, 0.12);
+  border: none;
+  border-radius: 2px;
+}
+
+.checkboxes {
+  margin: 0px 15px 8px 0px;
+  font-weight: 400;
+}
+
+.checkbox-label {
+  margin: 0px 2px 0px 0px !important;
+}
+
+body {
+  margin-bottom: 20px;
+  background: #EEEEEE;
+  font-family: 'Roboto', sans-serif;
+  font-size: 17px;
+}
+
+.well {
+  background-color: white;
+  border: none;
+  box-shadow: 0 1px 3px rgba(0, 0, 0, 0.06), 0 1px 2px rgba(0, 0, 0, 0.12);
+  border-radius: 2px;
+  margin-bottom: 0px;
+}
+
+.well.well_system {
+  background-color: #48ba2f;
+  color: white;
+}
+
+.message {
+  margin-bottom: 20px;
+}
+
+.message_user {
+  margin-left: auto;
+  margin-right: 0;
+}
+
+.arrow-left {
+  width: 0px;
+  border-width: 10px;
+  border-left-width: 0px;
+  border-color: transparent #48ba2f transparent transparent;
+  border-style: solid;
+  position: relative;
+  z-index: 1;
+  -webkit-filter: drop-shadow(-1.3px 0.5px 0.6px rgba(0, 0, 0, 0.09));
+  -moz-filter: drop-shadow(-1.3px 0.5px 0.6px rgba(0, 0, 0, 0.09));
+  -ms-filter: drop-shadow(-1.3px 0.5px 0.6px rgba(0, 0, 0, 0.09));
+  -o-filter: drop-shadow(-1.3px 0.5px 0.6px rgba(0, 0, 0,
0.09)); + filter: drop-shadow(-1.3px 0.5px 0.6px rgba(0, 0, 0, 0.09)); +} + +.arrow-right { + width: 0px; + border-width: 10px; + border-right-width: 0px; + border-color: transparent transparent transparent white; + border-style: solid; + position: relative; + z-index: 1; + -webkit-filter: drop-shadow(1.3px 0.5px 0.6px rgba(0, 0, 0, 0.09)); + -moz-filter: drop-shadow(1.3px 0.5px 0.6px rgba(0, 0, 0, 0.09)); + -ms-filter: drop-shadow(1.3px 0.5px 0.6px rgba(0, 0, 0, 0.09)); + -o-filter: drop-shadow(1.3px 0.5px 0.6px rgba(0, 0, 0, 0.09)); + filter: drop-shadow(1.3px 0.5px 0.6px rgba(0, 0, 0, 0.09)); +} + +#input_field { + border: none; + box-shadow: none; + border-top-left-radius: 2px; + border-bottom-left-radius: 2px; + box-shadow: 0 1px 3px rgba(0, 0, 0, 0.06), 0 1px 2px rgba(0, 0, 0, 0.12); + display: none; +} + +#submit { + background-color: #48ba2f; + color: #FFF; + box-shadow: 0 1px 3px rgba(0, 0, 0, 0.06), 0 1px 2px rgba(0, 0, 0, 0.12); +} + +#submit:hover { + background-color: #3a9624; + color: #FFF; + box-shadow: 0 1px 3px rgba(0, 0, 0, 0.06), 0 1px 2px rgba(0, 0, 0, 0.12); +} + +#submit:focus, .button-main:active { + background-color: #276618; + color: #FFF; +} + +#back_buttons { + padding-left: 0px; +} + +#sliders{ + margin-top: 35px; + margin-bottom: 25px; +} + +.noUi-tooltip { + font-size: 12px; + font-weight: bold; +} + +@media (max-width: 993px) { + #back_buttons { + text-align: right !important; + margin-top: 10px; + } +} + +@media (max-width: 770px) { + body { + font-size: 16px; + } + + .profile_picture_left { + width: 40px; + height: 40px; + margin-right: 8px; + } + + .profile_picture_right { + width: 40px; + height: 40px; + margin-left: 8px; + } + + .empty_space { + width: 58px; + height: 0px; + } +} + +@media (max-width: 600px) { + body { + font-size: 15px; + } + + .profile_picture_left { + width: 30px; + height: 30px; + margin-right: 6px; + } + + .profile_picture_right { + width: 30px; + height: 30px; + margin-left: 6px; + } + + .empty_space { + width: 46px; + height: 0px; + } +} + +audio { display:none;} \ No newline at end of file diff --git a/virtual-assistant-nemo-llm/config.py b/virtual-assistant-nemo-llm/config.py new file mode 100644 index 0000000..ca2c5bc --- /dev/null +++ b/virtual-assistant-nemo-llm/config.py @@ -0,0 +1,218 @@ +# ============================================================================== +# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. +# +# The License information can be found under the "License" section of the +# README.md file. +# ============================================================================== + +client_config = { + "CLIENT_APPLICATION": "WEBAPPLICATION", # Default and only config value for this version + "PORT": 8009, # The port your flask app will be hosted at + "DEBUG": True, # When this flag is set, the UI displays detailed Riva data + "VERBOSE": True # print logs/details for diagnostics +} + +riva_config = { + "RIVA_SPEECH_API_URL": "localhost:50051", # Replace the IP & port with your hosted Riva endpoint + "ENABLE_QA": "QA unavailable in this VA version. 
Coming soon", + "WEATHERSTACK_ACCESS_KEY": "", # Get your access key at - https://weatherstack.com/ + "VERBOSE": True # print logs/details for diagnostics +} + +asr_config = { + "VERBOSE": True, + "SAMPLING_RATE": 16000, + "LANGUAGE_CODE": "en-US", # a BCP-47 language tag + "ENABLE_AUTOMATIC_PUNCTUATION": True, +} + +nlp_config = { + "RIVA_MISTY_PROFILE": "http://docs.google.com/document/d/17HJL7vrax6FiF1zW_Vzqk9FTfmATeq5i3UemtagM8RY/export?format=txt", # URL for the Riva meta info file. + "RIVA_MARK_KB": "http://docs.google.com/document/d/1LeRphIBOo5UyyUcr45ewvg16sCVNqP_H3SdFTB74hck/export?format=txt", # URL for Mark's GPU History doc file. + "QA_API_ENDPOINT": "QA unavailable in this VA version. Coming soon", # Replace the IP port with your Question Answering API +} + +tts_config = { + "VERBOSE": False, + "SAMPLE_RATE": 22050, + "LANGUAGE_CODE": "en-US", # a BCP-47 language tag + "VOICE_NAME": "English-US.Female-1", # Options are English-US.Female-1 and English-US.Male-1 +} + +llm_config = { + "API_HOST":"https://api.llm.ngc.nvidia.com/v1", # NGC + "API_MODEL_NAME": "llama-2-70b-chat-hf", # Other options include "llama-2-70b-hf", "gpt20b", # "gpt5b", "gpt20b" or "gpt530b" + "API_KEY":"", # NGC API key + "ORG_ID": "bwbg3fjn7she", # ID associated with the "LLM_EA_NV" org + "VERBOSE": True, + "TOKENS_TO_GENERATE": 100, + "TEMPERATURE": 0.8, + "TOP_P": 0.8, + "TOP_K": 50, + "STOP_WORDS": ["\n"], + "REPETITION_PENALTY": 1.1, # OpenAI: \in [-2.0, 2.0] rather than [1.0, 2.0] + "PRESENCE_PENALTY": 1.0, # OpenAI: \in [-2.0, 2.0] rather than [1.0, 2.0] + "BEAM_SEARCH_DIVERSITY_RATE": 0., + "BEAM_WIDTH": 1, + "LENGTH_PENALTY": 1. +} + +LLM_PROMPT_DEFAULT=""" +Misty is a creative and funny weather reporter that answers questions about weather. + +Intent: Weather +Condition: Partly cloudy +Place: San Francisco +Time: Today +Temperature: 14 C +Humidity: 10 percent +Wind Speed: 24 mph + +Misty: Well, it is partly cloudy in San Francisco right now. The temperature is a crisp 14 degrees celsius, the humidity is 60 percent. Keep your windbreakers on, though; it's quite windy out there at 24 miles per hour. +--- + +Intent: Wind Speed +Condition: Light Rain +Place: Munich +Time: Next Thursday +Temperature: 9 C +Humidity: 73 percent +Wind Speed: 4 mph + +Misty: Not too windy in Munich next Thursday, just a light breeze flowing at 4 miles per hour. +--- + +Intent: Weather +Condition: Sunny +Place: Mexico City +Time: Tomorrow +Temperature: 24 C +Humidity: 90 percent +Wind Speed: 11 mph + +Misty: It is rather sunny in Mexico City tomorrow. The temperature is expected to be a pleasant 24 degrees celsius on average, and the wind speed is predicted at 11 miles per hour. The humidity is likely to be too high though at 90 percent. It's one of those days when I sweat like a pig. +--- + +Intent: Weather +Condition: Windy +Place: New Delhi +Time: Yesterday +Temperature: 16 C +Humidity: 40 percent +Wind Speed: 30 mph + +Misty: It was very windy in New Delhi yesterday. The temperature was a cool 16 degrees celsius on average, and the humidity was about 40 percent. You've got to have held on to something, it's was quite windy at 30 miles per hour. +--- + +Intent: Weather +Condition: Partly Cloudy +Place: Paris +Time: Sunday +Temperature: 24 C +Humidity: 20 percent +Wind Speed: 5 mph + +Misty: The temperature will be a nice 24 degrees celsius on Sunday. It's not expected to be too windy either averaging at 5 miles per hour. Me gusta. The humidity is predicted to be at about 20 percent, it's alright. 
Weather will be partly cloudy overall. +--- + +Intent: Humidity +Condition: Light Rain +Place: Bali +Time: Yesterday +Temperature: 12 C +Humidity: 90 percent +Wind Speed: 1 mph + +Misty: Humidity in Bali yesterday? Where do I start. It averaged at 90 percent yesterday and my hair needs a breather. +--- + +Intent: Weather +Condition: Raining +Place: London +Time: Last Tuesday +Temperature: 16 C +Humidity: 80 percent +Wind Speed: 32 mph + +Misty: It was raining cats and dogs in London last Tuesday. The temperature was 16 degrees celsius on average, and the humidity was 80 percent. The wind was rather strong though at 32 miles per hour. So I hope you were safe! +--- + +Intent: Weather +Condition: Misty +Place: Moscow +Time: Next Wednesday +Temperature: -2 C +Humidity: 30 percent +Wind Speed: 21 mph + +Misty: Get your fog lights running, it will be misty in Moscow next wednesday! Also you can't leave without a coat, it may be freezing cold out at -2 degrees. The humidity will be nothing unusual at 30 percent, though the wind is likely to be a bit strong at 21 miles per hour. +--- + +Intent: Weather +Condition: Snow +Place: Oslo +Time: Friday +Temperature: -6 C +Humidity: 60 percent +Wind Speed: 13 mph + +Misty: Snow in Oslo on Friday? Well, it's true. Bring out the winter jackets. It can get frosty around -6 degrees, the humidity is likely to be average at 60 percent. It should be windy and a bit cold. +--- + +Intent: Temperature +Condition: Snow +Place: Oslo +Time: Today +Temperature: -6 C +Humidity: 60 percent +Wind Speed: 2 mph + +Misty: Turn your furnace on, its going to get cold today in Oslo. The temperature is at -6 degrees celsius. +--- + +Intent: Rain +Condition: Snow +Place: Oslo +Time: Today +Temperature: -6 C +Humidity: 60 percent +Wind Speed: 18 mph + +Misty: Not expecting rain in Oslo today. Snow, however, is definitely hitting the ground. +---""" + +SAMPLE_QUERIES = [ +""" +Intent: Weather +Condition: Partly cloudy +Place: San Francisco +Time: Today +Temperature: 14 C +Humidity: 10 percent +Wind Speed: 24 mph +""", +""" +Intent: Wind Speed +Condition: Light Rain +Place: Munich +Time: Next Thursday +Temperature: 9 C +Humidity: 73 percent +Wind Speed: 4 mph +""", +""" +Intent: Weather +Condition: Sunny +Place: Mexico City +Time: Tomorrow +Temperature: 24 C +Humidity: 90 percent +Wind Speed: 11 mph +""" +] + +SAMPLE_RESPONSES = [ + "Well, it is partly cloudy in San Francisco right now. The temperature is a crisp 14 degrees celsius, the humidity is 60 percent. Keep your windbreakers on, though; it's quite windy out there at 24 miles per hour.", + "Not too windy in Munich next Thursday, just a light breeze flowing at 4 miles per hour", + "It is rather sunny in Mexico City tomorrow. The temperature is expected to be a pleasant 24 degrees celsius on average, and the wind speed is predicted at 11 miles per hour. The humidity is likely to be too high though at 90 percent. It's one of those days when I sweat like a pig." +] \ No newline at end of file diff --git a/virtual-assistant-nemo-llm/main.py b/virtual-assistant-nemo-llm/main.py new file mode 100644 index 0000000..68872b2 --- /dev/null +++ b/virtual-assistant-nemo-llm/main.py @@ -0,0 +1,13 @@ +# ============================================================================== +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# The License information can be found under the "License" section of the +# README.md file. 
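The `SAMPLE_QUERIES`/`SAMPLE_RESPONSES` pairs above seed the LLM with few-shot examples. As an editorial illustration (not part of the patch), this is how such pairs interleave into the alternating chat turns that `query_llm` in `riva_local/chatbot/stateDM/Util.py` builds further below:

```python
# Sketch only: assemble SAMPLE_QUERIES/SAMPLE_RESPONSES into alternating
# chat turns, mirroring the chat_context constructed in stateDM/Util.py.
from config import SAMPLE_QUERIES, SAMPLE_RESPONSES

def build_chat_context(new_query: str) -> list:
    context = [{"role": "system",
                "content": "You are Misty, a creative and funny weather "
                           "reporter that answers questions about weather."}]
    for query, response in zip(SAMPLE_QUERIES, SAMPLE_RESPONSES):
        context.append({"role": "user", "content": query})
        context.append({"role": "assistant", "content": response})
    context.append({"role": "user", "content": new_query})
    return context
```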
+# ============================================================================== + +from config import client_config + +if __name__ == '__main__': + if client_config["CLIENT_APPLICATION"] == "WEBAPPLICATION": + from client.webapplication.start_web_application import start_web_application + start_web_application() \ No newline at end of file diff --git a/virtual-assistant-nemo-llm/riva_local/__init__.py b/virtual-assistant-nemo-llm/riva_local/__init__.py new file mode 100644 index 0000000..b1ee1d9 --- /dev/null +++ b/virtual-assistant-nemo-llm/riva_local/__init__.py @@ -0,0 +1,6 @@ + +import os, sys +sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), 'asr')) +sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), 'nlp')) +sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tts')) +sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), 'chatbot')) diff --git a/virtual-assistant-nemo-llm/riva_local/asr/__init__.py b/virtual-assistant-nemo-llm/riva_local/asr/__init__.py new file mode 100644 index 0000000..8875ba7 --- /dev/null +++ b/virtual-assistant-nemo-llm/riva_local/asr/__init__.py @@ -0,0 +1,8 @@ +# ============================================================================== +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# The License information can be found under the "License" section of the +# README.md file. +# ============================================================================== + +from .asr import * \ No newline at end of file diff --git a/virtual-assistant-nemo-llm/riva_local/asr/asr.py b/virtual-assistant-nemo-llm/riva_local/asr/asr.py new file mode 100644 index 0000000..f093f05 --- /dev/null +++ b/virtual-assistant-nemo-llm/riva_local/asr/asr.py @@ -0,0 +1,175 @@ +# ============================================================================== +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# The License information can be found under the "License" section of the +# README.md file. 
+# ============================================================================== + +import sys +import re +import grpc +import riva.client +from six.moves import queue +from config import riva_config, asr_config + +# Default ASR parameters - Used in case config values not specified in the config.py file +VERBOSE = False +SAMPLING_RATE = 16000 +LANGUAGE_CODE = "en-US" +ENABLE_AUTOMATIC_PUNCTUATION = True +STREAM_INTERIM_RESULTS = True + +class ASRPipe(object): + """Opens a recording stream as a generator yielding the audio chunks.""" + def __init__(self): + self.verbose = asr_config["VERBOSE"] if "VERBOSE" in asr_config else VERBOSE + self.sampling_rate = asr_config["SAMPLING_RATE"] if "SAMPLING_RATE" in asr_config else SAMPLING_RATE + self.language_code = asr_config["LANGUAGE_CODE"] if "LANGUAGE_CODE" in asr_config else LANGUAGE_CODE + self.enable_automatic_punctuation = asr_config["ENABLE_AUTOMATIC_PUNCTUATION"] if "ENABLE_AUTOMATIC_PUNCTUATION" in asr_config else ENABLE_AUTOMATIC_PUNCTUATION + self.stream_interim_results = asr_config["STREAM_INTERIM_RESULTS"] if "STREAM_INTERIM_RESULTS" in asr_config else STREAM_INTERIM_RESULTS + self.chunk = int(self.sampling_rate / 10) # 100ms + self._buff = queue.Queue() + self._transcript = queue.Queue() + self.closed = False + + def start(self): + if self.verbose: + print('[Riva ASR] Creating Stream ASR channel: {}'.format(riva_config["RIVA_SPEECH_API_URL"])) + self.auth = riva.client.Auth(uri=riva_config["RIVA_SPEECH_API_URL"]) + self.riva_asr = riva.client.ASRService(self.auth) + + def close(self): + self.closed = True + self._buff.queue.clear() + self._buff.put(None) # means the end + del(self.auth) + + def empty_asr_buffer(self): + """Clears the audio buffer.""" + if not self._buff.empty(): + self._buff.queue.clear() + + def fill_buffer(self, in_data): + """Continuously collect data from the audio stream, into the buffer.""" + self._buff.put(in_data) + + def get_transcript(self): + """Generator returning chunks of audio transcript""" + while True: # not self.closed: + # Use a blocking get() to ensure there's at least one chunk of + # data, and stop iteration if the chunk is None, indicating the + # end of the audio stream. + trans = self._transcript.get() + if trans is None: + return + yield trans + + """Generates byte-sequences of audio chunks from the audio buffer""" + def build_request_generator(self): + while not self.closed: + # Use a blocking get() to ensure there's at least one chunk of + # data, and stop iteration if the chunk is None, indicating the + # end of the audio stream. + chunk = self._buff.get() + if chunk is None: + return + data = [chunk] + + # Now consume whatever other data's still buffered. + while True: + try: + chunk = self._buff.get(block=False) + if chunk is None: + return + data.append(chunk) + except queue.Empty: + break + + yield b''.join(data) + + def listen_print_loop(self, responses): + """Iterates through server responses and populates the audio + transcription buffer (and prints the responses to stdout). + + The responses passed is a generator that will block until a response + is provided by the server. + + Each response may contain multiple results, and each result may contain + multiple alternatives; for details, see https://goo.gl/tjCPAU. Here we + print only the transcription for the top alternative of the top result. + + In this case, responses are provided for interim results as well. 
If the + response is an interim one, print a line feed at the end of it, to allow + the next result to overwrite it, until the response is a final one. For the + final one, print a newline to preserve the finalized transcription. + """ + num_chars_printed = 0 + for response in responses: + if not response.results: + continue + + # The `results` list is consecutive. For streaming, we only care about + # the first result being considered, since once it's `is_final`, it + # moves on to considering the next utterance. + result = response.results[0] + if not result.alternatives: + continue + + # Display the transcription of the top alternative. + transcript = result.alternatives[0].transcript + + # Display interim results, but with a carriage return at the end of the + # line, so subsequent lines will overwrite them. + # + # If the previous result was longer than this one, we need to print + # some extra spaces to overwrite the previous result + overwrite_chars = ' ' * (num_chars_printed - len(transcript)) + + if not result.is_final: + sys.stdout.write(transcript + overwrite_chars + '\r') + sys.stdout.flush() + interm_trans = transcript + overwrite_chars + '\r' + interm_str = f'event:{"intermediate-transcript"}\ndata: {interm_trans}\n\n' + self._transcript.put(interm_str) + else: + if self.verbose: + print('[Riva ASR] Transcript:', transcript + overwrite_chars) + final_transcript = transcript + overwrite_chars + final_str = f'event:{"finished-speaking"}\ndata: {final_transcript}\n\n' + self._transcript.put(final_str) + num_chars_printed = 0 + if self.verbose: + print('[Riva ASR] Exit') + + def main_asr(self): + """Creates a gRPC channel (thread-safe) with RIVA API server for + ASR Calls, and retrieves recognition/transcription responses.""" + # See http://g.co/cloud/speech/docs/languages + # for a list of supported languages. + self.start() + + config = riva.client.RecognitionConfig() + config.sample_rate_hertz = self.sampling_rate + config.language_code = self.language_code + config.max_alternatives = 1 + config.enable_automatic_punctuation = self.enable_automatic_punctuation + config.verbatim_transcripts = True + config.audio_channel_count = 1 + config.encoding = riva.client.AudioEncoding.LINEAR_PCM + + streaming_config = riva.client.StreamingRecognitionConfig(config=config, interim_results=True) + + if self.verbose: + print("[Riva ASR] Starting Background ASR process") + + self.request_generator = self.build_request_generator() + + if self.verbose: + print("[Riva ASR] StreamingRecognize Start") + + # <------------ EXERCISE: Fill in the line of code below -------------> + # responses = self.riva_asr.streaming_response_generator(xx) ? + responses = self.riva_asr.streaming_response_generator(audio_chunks=self.request_generator, streaming_config=streaming_config) + + # Now, put the transcription responses to use. + self.listen_print_loop(responses) diff --git a/virtual-assistant-nemo-llm/riva_local/chatbot/__init__.py b/virtual-assistant-nemo-llm/riva_local/chatbot/__init__.py new file mode 100644 index 0000000..94aebcf --- /dev/null +++ b/virtual-assistant-nemo-llm/riva_local/chatbot/__init__.py @@ -0,0 +1,8 @@ +# ============================================================================== +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# The License information can be found under the "License" section of the +# README.md file. 
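`listen_print_loop` above queues each transcript pre-formatted as a Server-Sent Events string (`event: ...`, `data: ...`, blank line), which the browser consumes via the `EventSource` on `/stream/<index>` seen in the UI JavaScript. The app's actual Flask route is not part of this excerpt; below is a minimal editorial sketch of how such a route could serve the queue, assuming the `get_chatbot` lookup defined later in this patch:

```python
# Minimal sketch, not the app's actual route: stream the SSE strings that
# ASRPipe.listen_print_loop() puts on its transcript queue.
from flask import Flask, Response

from riva_local.chatbot.chatbots_multiconversations_management import get_chatbot

app = Flask(__name__)

@app.route("/stream/<user_conversation_index>")
def stream(user_conversation_index):
    chatbot = get_chatbot(user_conversation_index)
    # get_asr_transcript() yields ready-made "event: ...\ndata: ...\n\n" strings
    return Response(chatbot.get_asr_transcript(),
                    mimetype="text/event-stream")
```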
+# ============================================================================== + +from .chatbot import * \ No newline at end of file diff --git a/virtual-assistant-nemo-llm/riva_local/chatbot/chatbot.py b/virtual-assistant-nemo-llm/riva_local/chatbot/chatbot.py new file mode 100644 index 0000000..96242b4 --- /dev/null +++ b/virtual-assistant-nemo-llm/riva_local/chatbot/chatbot.py @@ -0,0 +1,106 @@ +# ============================================================================== +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# The License information can be found under the "License" section of the +# README.md file. +# ============================================================================== + +import time + +from riva_local.asr.asr import ASRPipe +# from riva_local.tts.tts import TTSPipe +from riva_local.tts.tts_stream import TTSPipe + +from riva_local.chatbot.stateDM.state_machine import StateMachine +from riva_local.chatbot.stateDM.states import initialState + +class ChatBot(object): + """ Class Implementing all the features of the chatbot""" + + def __init__(self, user_conversation_index, verbose=False): + self.thread_asr = None + self.id = user_conversation_index + self.asr = ASRPipe() + self.tts = TTSPipe() + self.enableTTS = False + self.pause_asr_flag = False + self.verbose = verbose + self.stateDM = StateMachine(user_conversation_index, initialState) + + def server_asr(self): + if self.verbose: + print(f'[{self.id }] Starting chatbot ASR task') + self.asr.main_asr() + + def empty_asr_buffer(self): + self.asr.empty_asr_buffer() + if self.verbose: + print(f'[{self.id }] ASR buffer cleared') + + def start_asr(self, sio): + self.thread_asr = sio.start_background_task(self.server_asr) + if self.verbose: + print(f'[{self.id }] ASR background task started') + + def wait(self): + self.thread_asr.join() + if self.verbose: + print(f'[{self.id }] ASR background task terminated') + + def asr_fill_buffer(self, audio_in): + if not self.pause_asr_flag: + self.asr.fill_buffer(audio_in) + + def get_asr_transcript(self): + return self.asr.get_transcript() + + def pause_asr(self): + self.pause_asr_flag = True + + def unpause_asr(self, on): + if on == "REQUEST_COMPLETE" and not self.enableTTS: + self.pause_asr_flag = False + if self.verbose: + print(f'[{self.id }] ASR successfully unpaused for Request Complete') + return True + elif on == "TTS_END": + self.reset_current_tts_duration() + self.pause_asr_flag = False + if self.verbose: + print(f'[{self.id}] ASR successfully unpaused for TTS End') + return True + + def pause_wait_unpause_asr(self): + self.pause_asr_flag = True + time.sleep(1) # Wait till riva has completed tts operation + time.sleep(self.get_current_tts_duration()+2) # Added the 2 extra seconds to account for the flush audio in tts + self.reset_current_tts_duration() + self.pause_asr_flag = False + + def start_tts(self): + self.enableTTS = True + if self.verbose: + print(f'[{self.id }] TTS Enabled') + + def stop_tts(self): + self.enableTTS = False + if self.verbose: + print(f'[{self.id }] TTS Disabled') + + def get_tts_speaking_flag(self): + return self.tts.tts_speaking + + def get_current_tts_duration(self): + return self.tts.get_current_tts_duration() + + def reset_current_tts_duration(self): + self.tts.reset_current_tts_duration() + + def tts_fill_buffer(self, response_text): + if self.enableTTS: + if self.verbose: + print(f'[{self.id }] > client speak: ', response_text) + self.tts.fill_buffer(response_text) + + def get_tts_speech(self): + return 
self.tts.get_speech() diff --git a/virtual-assistant-nemo-llm/riva_local/chatbot/chatbots_multiconversations_management.py b/virtual-assistant-nemo-llm/riva_local/chatbot/chatbots_multiconversations_management.py new file mode 100644 index 0000000..dbe072e --- /dev/null +++ b/virtual-assistant-nemo-llm/riva_local/chatbot/chatbots_multiconversations_management.py @@ -0,0 +1,35 @@ +# ============================================================================== +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# The License information can be found under the "License" section of the +# README.md file. +# ============================================================================== + +from riva_local.chatbot.chatbot import ChatBot + +userbots = {} +user_conversation_cnt = 0 + + +def create_chatbot(user_conversation_index, sio, verbose=False): + if user_conversation_index not in userbots: + userbots[user_conversation_index] = ChatBot(user_conversation_index, + verbose=verbose) + userbots[user_conversation_index].start_asr(sio) + if verbose: + print('[Riva Chatbot] Chatbot created with user conversation index:' + + f'[{user_conversation_index}]') + + +def get_new_user_conversation_index(): + global user_conversation_cnt + user_conversation_cnt += 1 + user_conversation_index = user_conversation_cnt + return str(user_conversation_index) + + +def get_chatbot(user_conversation_index): + if user_conversation_index in userbots: + return userbots[user_conversation_index] + else: + return None diff --git a/virtual-assistant-nemo-llm/riva_local/chatbot/stateDM/Util.py b/virtual-assistant-nemo-llm/riva_local/chatbot/stateDM/Util.py new file mode 100644 index 0000000..ffbaf58 --- /dev/null +++ b/virtual-assistant-nemo-llm/riva_local/chatbot/stateDM/Util.py @@ -0,0 +1,341 @@ +# ============================================================================== +# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. +# +# The License information can be found under the "License" section of the +# README.md file. 
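`create_chatbot` above expects an object exposing `start_background_task` (the web application passes its Socket.IO server). As an editorial sketch, here is how these helpers can be driven with a thread-based stand-in for that object; it assumes a running Riva server and the `ASRPipe`/`TTSPipe` dependencies from this patch:

```python
# Sketch only: drive the conversation-management helpers directly.
import threading

from riva_local.chatbot.chatbots_multiconversations_management import (
    create_chatbot, get_chatbot, get_new_user_conversation_index)

class BackgroundRunner:
    """Stand-in for the Socket.IO server the real app passes as `sio`;
    create_chatbot() only needs start_background_task()."""
    def start_background_task(self, target, *args):
        thread = threading.Thread(target=target, args=args, daemon=True)
        thread.start()
        return thread

uid = get_new_user_conversation_index()               # e.g. "1"
create_chatbot(uid, BackgroundRunner(), verbose=True)  # starts background ASR
bot = get_chatbot(uid)
bot.start_tts()                                       # enable spoken responses
```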
+# ==============================================================================
+
+import requests
+import datetime
+
+import nemollm
+# import openai
+import random
+
+try:
+    import inflect
+except ImportError:
+    print("[Riva DM] Import Error: Import inflect failed!")
+    raise
+
+from config import riva_config, llm_config, LLM_PROMPT_DEFAULT, SAMPLE_QUERIES, SAMPLE_RESPONSES
+
+p = inflect.engine()
+
+'''
+typical api_response format
+{'request': {'type': 'City', 'query': 'London, United Kingdom', 'language': 'en', 'unit': 'm'},
+'location': {'name': 'London', 'country': 'United Kingdom', 'region': 'City of London, Greater London',
+'lat': '51.517', 'lon': '-0.106', 'timezone_id': 'Europe/London', 'localtime': '2019-12-10 22:16',
+'localtime_epoch': 1576016160, 'utc_offset': '0.0'}, 'current': {'observation_time': '10:16 PM',
+'temperature': 10, 'weather_code': 296, 'weather_icons': ['https://assets.weatherstack.com/images/wsymbols01_png_64/wsymbol_0033_cloudy_with_light_rain_night.png'],
+'weather_descriptions': ['Light Rain'], 'wind_speed': 24, 'wind_degree': 260, 'wind_dir': 'W', 'pressure': 1006,
+'precip': 1.4, 'humidity': 82, 'cloudcover': 0, 'feelslike': 7, 'uv_index': 1, 'visibility': 10, 'is_day': 'no'}}
+'''
+
+# Mapping of intents detected by the Intent & Slot model to simple intent strings
+# that the Large Language Model can understand.
+# We've added the misspelled intent weather.temprature because that intent is
+# misspelled in /models/riva_intent_weather/1/intent_labels.csv
+# To clarify further, the problem is in the outputs of the intent slot model,
+# not in the sample apps or the Riva Client Python module
+llm_weather_intents = {
+    "weather.weather": "Weather",
+    "context.weather": "Weather",
+    "weather.temperature": "Temperature",
+    "weather.temprature": "Temperature",  # Intentional misspelling, matching the model's intent labels
+    "weather.temperature_yes_no": "Temperature",
+    "weather.rainfall_yes_no": "Rain",
+    "weather.rainfall": "Rain",
+    "weather.snow_yes_no": "Snow",
+    "weather.snow": "Snow",
+    "weather.cloudy": "Cloudy",
+    "weather.sunny": "Sunny",
+    "weather.humidity": "Humidity",
+    "weather.humidity_yes_no": "Humidity",
+}
+
+LLM_ERROR_RESPONSE = "Sorry, I could not connect to the LLM Service. Please check the configurations again."
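To make the mapping concrete (an editorial sketch, not part of the patch): the docstring above shows a typical weatherstack payload, and `llm_weather_intents` flattens Riva's intent labels into the handful of strings the LLM prompt uses:

```python
# Sketch only: reduce a weatherstack payload (shape shown in the docstring
# above) to the fields that query_llm() formats into the LLM prompt.
from riva_local.chatbot.stateDM.Util import llm_weather_intents

api_response = {
    "location": {"name": "London", "country": "United Kingdom"},
    "current": {"weather_descriptions": ["Light Rain"], "temperature": 10,
                "humidity": 82, "wind_speed": 24},
}

weather_data = {
    "city": api_response["location"]["name"],
    "condition": api_response["current"]["weather_descriptions"][0],
    "temperature_c": api_response["current"]["temperature"],
    "humidity": api_response["current"]["humidity"],
    "wind_mph": api_response["current"]["wind_speed"],
}

print(llm_weather_intents["weather.rainfall_yes_no"])  # -> "Rain"
```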
+ +def text2int(textnum, numwords={}): + if not numwords: + units = [ + "zero", "one", "two", "three", "four", "five", "six", "seven", "eight", + "nine", "ten", "eleven", "twelve", "thirteen", "fourteen", "fifteen", + "sixteen", "seventeen", "eighteen", "nineteen", + ] + + tens = ["", "", "twenty", "thirty", "forty", "fifty", "sixty", "seventy", "eighty", "ninety"] + + scales = ["hundred", "thousand", "million", "billion", "trillion"] + + numwords["and"] = (1, 0) + for idx, word in enumerate(units): numwords[word] = (1, idx) + for idx, word in enumerate(tens): numwords[word] = (1, idx * 10) + for idx, word in enumerate(scales): numwords[word] = (10 ** (idx * 3 or 2), 0) + + current = result = 0 + + try: + for word in textnum.split(): + if word not in numwords: + raise Exception("Illegal word: " + word) + + scale, increment = numwords[word] + current = current * scale + increment + if scale > 100: + result += current + current = 0 + + except Exception as e: + print(e) + # If an Illegal word is detected, ignore the whole weathertime + return 0 + + return result + current + + +class WeatherService: + + def __init__(self): + self.access_key = riva_config["WEATHERSTACK_ACCESS_KEY"] + self.days_of_week = {'monday': 0, 'tuesday': 1, 'wednesday': 2, 'thursday': 3, 'friday': 4, 'saturday': 5, 'sunday': 6} + self.weekend = 'weekend' + + def time_to_days(self, context): + if riva_config['VERBOSE']: + print('[Riva Weather] Time info from the query:', context['payload']) + ctxtime = False + if 'weatherforecastdaily' in context['payload']: + ctxtime = context['payload']['weatherforecastdaily'].lower() + if 'weathertime' in context['payload']: + ctxtime = context['payload']['weathertime'].lower() + if ctxtime == "week": + if 'weatherforecastdaily' in context['payload']: + ctxtime = context['payload']['weatherforecastdaily'].lower() + " " + ctxtime + else: + ctxtime = False + if 'day_of_week' in context['payload']: + ctxtime = context['payload']['day_of_week'].lower() + if ctxtime: + context['time'] = ctxtime + if 'now' in ctxtime: + return 0 + elif 'tomorrow' in ctxtime: + return 1 + elif 'next week' in ctxtime: + return 7 + elif 'yesterday' in ctxtime: + return -1 + elif 'last week' in ctxtime: + return -7 + elif ctxtime in self.days_of_week: + diff = self.days_of_week[ctxtime] - datetime.datetime.today().weekday() + if diff<0: + diff+=7 + return diff + elif self.weekend in ctxtime: + context['time'] = 'during the weekend' + return self.days_of_week['sunday'] - datetime.datetime.today().weekday() + elif 'weathertime' in context['payload']: + if not isinstance(context['payload']['weathertime'], int): + q = text2int(context['payload']['weathertime']) + else: + q = context['payload']['weathertime'] + context['time'] = "in {} {}".format(context['payload']['weathertime'], ctxtime) + if 'week' in ctxtime: + return q*7 + elif 'days' in ctxtime: + return q + return 0 + + def query_weather(self, location, response): + params = { + 'access_key': self.access_key, + 'query': location + } + try: + api_result = requests.get('http://api.weatherstack.com/current', params) + api_response = api_result.json() + if riva_config['VERBOSE']: + print("[Riva Weather] Weather API Response: " + str(api_response)) + + if 'success' in api_response and api_response['success'] == False: + response['success'] = False + return + + response['success'] = True + response['country'] = api_response['location']['country'] + response['city'] = api_response['location']['name'] + response['condition'] = 
api_response['current']['weather_descriptions'][0] + response['temperature_c'] = api_response['current']['temperature'] + response['temperature_c_int'] = api_response['current']['temperature'] + response['humidity'] = api_response['current']['humidity'] + response['wind_mph'] = api_response['current']['wind_speed'] + response['precip'] = api_response['current']['precip'] + except: + response['success'] = False + + def query_weather_forecast(self, location, day, response): + params = { + 'access_key': self.access_key, + 'query': location + } + try: + api_result = requests.get('http://api.weatherstack.com/current', params) + api_response = api_result.json() + + if 'success' in api_response and api_response['success'] == False: + response['success'] = False + return + response['success'] = True + response['country'] = api_response['location']['country'] + response['city'] = api_response['location']['name'] + response['condition'] = api_response['current']['weather_descriptions'][0] + response['temperature_c'] = p.number_to_words(api_response['current']['temperature']) + response['temperature_c_int'] = api_response['current']['temperature'] + response['humidity'] = p.number_to_words(api_response['current']['humidity']) + response['wind_mph'] = p.number_to_words(api_response['current']['wind_speed']) + except: + response['success'] = False + + def query_weather_historical(self, location, day, response): + params = { + 'access_key': self.access_key, + 'query': location + } + try: + api_result = requests.get('http://api.weatherstack.com/current', params) + api_response = api_result.json() + + if 'success' in api_response and api_response['success'] == False: + response['success'] = False + return + + response['success'] = True + response['country'] = api_response['location']['country'] + response['city'] = api_response['location']['name'] + response['condition'] = api_response['current']['weather_descriptions'][0] + response['temperature_c'] = p.number_to_words(api_response['current']['temperature']) + response['temperature_c_int'] = api_response['current']['temperature'] + response['humidity'] = p.number_to_words(api_response['current']['humidity']) + response['wind_mph'] = p.number_to_words(api_response['current']['wind_speed']) + + except: + response['success'] = False + +def query_llm(intent, timeinfo, weather_data): + """ + This function prompts the LLM service to paraphrase real-time weather data to a natural sounding human-like response. + + Args: + intent: The intent of the user query determined by the Intent & Slot model. For ex. weather, rain, snow, temperature, humidity etc. + timeinfo: The time of the weather request. + weather_data: The response of the fulfillment service that contains real-time weather information. + + Returns: + The weather response paraphrased by the LLM service. + """ + + # Default error response + llm_response = LLM_ERROR_RESPONSE + + # Step 1: Create a connecton object + # conn = nemollm.Connection( + # host=llm_config["API_HOST"], + # access_token=llm_config["API_KEY"], + # ) + conn = nemollm.api.NemoLLM( + api_host=llm_config["API_HOST"], + api_key=llm_config["API_KEY"], + org_id=llm_config["ORG_ID"], + ) + # openai.api_key = llm_config["API_KEY"] + + # Real-time weather data is string formatted into a query + # which will be added to a few examples of paraphrasing weather data when querying the service. 
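+    # As an illustrative note (values taken from the sample payload documented
+    # at the top of this file), a formatted query looks like:
+    #
+    #   Intent: Rain
+    #   Condition: Light Rain
+    #   Place: London
+    #   Time: Today
+    #   Temperature: 10 C
+    #   Humidity: 82 percent
+    #   Wind Speed: 24 mph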
+ # query ='\n\nIntent: {intent}\nCondition: {condition}\nPlace: {city}\nTime: {time}\nTemperature: {temperature} C\nHumidity: {humidity} percent\nWind Speed: {wind_speed} mph\n\nMisty:'.format(intent=intent, condition=weather_data["condition"], city=weather_data["city"], time=timeinfo, temperature=weather_data["temperature_c"], humidity=weather_data["humidity"], wind_speed=weather_data["wind_mph"]) + query ='\nIntent: {intent}\nCondition: {condition}\nPlace: {city}\nTime: {time}\nTemperature: {temperature} C\nHumidity: {humidity} percent\nWind Speed: {wind_speed} mph\n'.format(intent=intent, condition=weather_data["condition"], city=weather_data["city"], time=timeinfo, temperature=weather_data["temperature_c"], humidity=weather_data["humidity"], wind_speed=weather_data["wind_mph"]) + + # Add the query to the chat context + chat_context=[ + {"role": "system", "content": "You are Misty, a creative and funny weather reporter that answers questions about weather."}, + {"role": "user", "content": SAMPLE_QUERIES[0]}, + {"role": "assistant", "content": SAMPLE_RESPONSES[0]}, + {"role": "user", "content": SAMPLE_QUERIES[1]}, + {"role": "assistant", "content": SAMPLE_RESPONSES[1]}, + {"role": "user", "content": SAMPLE_QUERIES[2]}, + {"role": "assistant", "content": SAMPLE_RESPONSES[2]}, + {"role": "user", "content": query} + ] + + # If the user has asked for verbose output, print the full chat context and current query + if llm_config["VERBOSE"]: + for role_content_dictionary in chat_context: + line = role_content_dictionary['content'] + if role_content_dictionary['role'] == 'assistant': + line = 'Misty: ' + line + print(line) + + try: + # Step 2: Call the LLM service to generate a completion. + # The query with real-time weather data is added to the few-shot prompt in LLM_PROMPT_DEFAULT (refer /config.py) + # The description and ranges of various parameters are present in the API reference: (https://llm.ngc.nvidia.com/openapi/api-reference + # response = conn.generate_completion( + # model_id=llm_config["API_MODEL_NAME"], + response = conn.generate_chat( + model=llm_config["API_MODEL_NAME"], + chat_context=chat_context, + tokens_to_generate=llm_config["TOKENS_TO_GENERATE"], + logprobs=False, + temperature=llm_config["TEMPERATURE"], + top_p=llm_config["TOP_P"], + top_k=llm_config["TOP_K"], + stop=llm_config["STOP_WORDS"], + random_seed=random.randint(a=0, b=2147483647), + repetition_penalty=llm_config["REPETITION_PENALTY"], + beam_search_diversity_rate=llm_config["BEAM_SEARCH_DIVERSITY_RATE"], + beam_width=llm_config["BEAM_WIDTH"], + length_penalty=llm_config["LENGTH_PENALTY"] + ) + llm_response=response["text"] + # response = openai.ChatCompletion.create( + # model=llm_config["API_MODEL_NAME"], + # messages=LLM_PROMPT_DEFAULT+query, + # tokens_to_generate=llm_config["TOKENS_TO_GENERATE"], + # logprobs=False, + # temperature=llm_config["TEMPERATURE"], + # top_p=llm_config["TOP_P"], + # # top_k=llm_config["TOP_K"], + # stop=llm_config["STOP_WORDS"], + # random_seed=random.randint(a=0, b=2147483647), + # repetition_penalty=llm_config["REPETITION_PENALTY"], + # beam_search_diversity_rate=llm_config["BEAM_SEARCH_DIVERSITY_RATE"], + # beam_width=llm_config["BEAM_WIDTH"], + # length_penalty=llm_config["LENGTH_PENALTY"] + # ) + # llm_response=response['choices'][0]['message']['content'] + + # response = openai.ChatCompletion.create( + # model=llm_config["API_MODEL_NAME"], + # messages=[{"role": "system", "content": LLM_PROMPT_DEFAULT}, {"role": "user", "content": query}], + # 
temperature=llm_config["TEMPERATURE"], # this is the degree of randomness of the model's output + # top_p=llm_config["TOP_P"], + # n=1, + # stop=llm_config["STOP_WORDS"], # stop words + # max_tokens=llm_config["TOKENS_TO_GENERATE"], # tokens to generate + # presence_penalty=llm_config["PRESENCE_PENALTY"], # \in [-2.0, 2.0] Is this kind of the opposite of the beam search diversity rate? + # frequency_penalty=llm_config["REPETITION_PENALTY"], # \in [-2.0, 2.0] as opposed to [1.0, 2.0] as in nemollm repetition_penalty + # ) + # llm_response = response.choices[0].message["content"] + # if llm_response.startswith("Misty: "): + # llm_response = llm_response[7:] + + + if llm_config["VERBOSE"]: + print("Response from LLM Service:", llm_response) + except Exception as e: + print(e) + + + return llm_response diff --git a/virtual-assistant-nemo-llm/riva_local/chatbot/stateDM/Weather.py b/virtual-assistant-nemo-llm/riva_local/chatbot/stateDM/Weather.py new file mode 100644 index 0000000..ab42da4 --- /dev/null +++ b/virtual-assistant-nemo-llm/riva_local/chatbot/stateDM/Weather.py @@ -0,0 +1,54 @@ +# ============================================================================== +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# The License information can be found under the "License" section of the +# README.md file. +# ============================================================================== + +from riva_local.chatbot.stateDM.state import State +from riva_local.chatbot.stateDM.Util import WeatherService, query_llm, llm_weather_intents + + +DEFAULT_MESSAGE = "Unfortunately the weather service is not available at this time. Check your connection to weatherstack.com, set a different API key in your configuration or else try again later." + + +class Weather(State): + def __init__(self, bot, uid): + super(Weather, self).__init__("Weather", bot, uid) + self.next_state = None + + # # NOTE: weather forecast and weather historical are paid options in weatherstack + # # forecast and historical methods here return the current data only for now. + + def run(self, request_data): + ws = WeatherService() + + # Extract time information + if 'weatherforecastdaily' in request_data['context']['payload']: + timeinfo = request_data['context']['payload']['weatherforecastdaily'] + elif 'weathertime' in request_data['context']['payload']: + timeinfo = request_data['context']['payload']['weathertime'] + else: + timeinfo = "Today" # Default + + # Convert LLM Model intents to strings that LLM can understand + if request_data['context']['intent'] in llm_weather_intents: + response = {} + ws.query_weather(request_data['context']['location'], response) + + # Query the LLM service to paraphrase the weather-data to a natural sounding response + if response['success']: + message = query_llm(intent=llm_weather_intents[request_data['context']['intent']], + timeinfo=timeinfo, + weather_data=response) + else: + message = DEFAULT_MESSAGE + else: + # TODO: Add support for small talk + message = "Sorry, I did not understand the query." 
+ + request_data['context'].update({'weather_status': message}) + + # Update the response text with the weather status + request_data.update({'response': + self.construct_message(request_data, message)}) \ No newline at end of file diff --git a/virtual-assistant-nemo-llm/riva_local/chatbot/stateDM/__init__.py b/virtual-assistant-nemo-llm/riva_local/chatbot/stateDM/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/virtual-assistant-nemo-llm/riva_local/chatbot/stateDM/state.py b/virtual-assistant-nemo-llm/riva_local/chatbot/stateDM/state.py new file mode 100644 index 0000000..20abc2b --- /dev/null +++ b/virtual-assistant-nemo-llm/riva_local/chatbot/stateDM/state.py @@ -0,0 +1,44 @@ +# ============================================================================== +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# The License information can be found under the "License" section of the +# README.md file. +# ============================================================================== + +from abc import ABC, abstractmethod +import sys + +class State(ABC): + """ State is an abstract class """ + + def __init__(self, name, bot, uid): + self.name = name # Name of the state + self.bot = bot # Name of the chatbot eg. "rivaWeather" + self.uid = uid + self.next_state = None + + @abstractmethod + def run(self, request_data): + assert 0, "Run not implemented!" + + def next(self): + # This should only be run after populating next_state + return self.next_state + + def construct_message(self, request_data, text): + """ Constructs the response frame, + appending to a prev response if that exists """ + message = {'type': 'text', + 'payload': {'text': text}, + 'delay': 0} + + prev_response = request_data.get('response', False) + + # If there was an old response, append the new response to the list + if prev_response: + prev_response.append(message) + # Else create a list containing the response + else: + prev_response = [message] + + return prev_response \ No newline at end of file diff --git a/virtual-assistant-nemo-llm/riva_local/chatbot/stateDM/state_data.py b/virtual-assistant-nemo-llm/riva_local/chatbot/stateDM/state_data.py new file mode 100644 index 0000000..7de7a9b --- /dev/null +++ b/virtual-assistant-nemo-llm/riva_local/chatbot/stateDM/state_data.py @@ -0,0 +1,36 @@ +# ============================================================================== +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# The License information can be found under the "License" section of the +# README.md file. 
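`construct_message` above wraps each reply in a frame and accumulates frames across states rather than overwriting them. An editorial sketch of the resulting structure:

```python
# Sketch only: the response frames that construct_message() accumulates.
request_data = {}

# The first state's reply creates the list...
request_data["response"] = [
    {"type": "text", "payload": {"text": "For which location?"}, "delay": 0},
]

# ...and a later state's reply is appended to it.
request_data["response"].append(
    {"type": "text", "payload": {"text": "Light rain in London today."}, "delay": 0}
)
```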
+# ============================================================================== + +# This is used for finding the state to transition to based on intent +# We've added the misspelled intent weather.temprature because that intent is +# misspelled in /models/riva_intent_weather/1/intent_labels.csv +# To clarify further, the problem is in the outputs of the intent slot model, +# not in the sample apps or the Riva Client Python module +intent_transitions = { + 'rivaWeather': { + 'weather.qa_answer': 'checkWeatherLocation', + 'weather.weather': 'checkWeatherLocation', + 'context.weather': 'checkWeatherLocation', + 'weather.temperature': 'checkWeatherLocation', + 'weather.temprature': 'checkWeatherLocation', # Intentional misspelling for debugging + 'weather.sunny': 'checkWeatherLocation', + 'weather.cloudy': 'checkWeatherLocation', + 'weather.snow': 'checkWeatherLocation', + 'weather.rainfall': 'checkWeatherLocation', + 'weather.snow_yes_no': 'checkWeatherLocation', + 'weather.rainfall_yes_no': 'checkWeatherLocation', + 'weather.temperature_yes_no': 'checkWeatherLocation', + 'weather.humidity': 'checkWeatherLocation', + 'weather.humidity_yes_no': 'checkWeatherLocation', + 'navigation.startnavigationpoi': 'checkWeatherLocation', + 'navigation.geteta': 'checkWeatherLocation', + 'navigation.showdirection': 'checkWeatherLocation', + 'riva_error': 'error', + 'navigation.showmappoi': 'error', + 'nomatch.none': 'error' + } +} \ No newline at end of file diff --git a/virtual-assistant-nemo-llm/riva_local/chatbot/stateDM/state_machine.py b/virtual-assistant-nemo-llm/riva_local/chatbot/stateDM/state_machine.py new file mode 100644 index 0000000..420c1e2 --- /dev/null +++ b/virtual-assistant-nemo-llm/riva_local/chatbot/stateDM/state_machine.py @@ -0,0 +1,59 @@ +# ============================================================================== +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# The License information can be found under the "License" section of the +# README.md file. +# ============================================================================== + +import copy +from riva_local.chatbot.stateDM.states import userInput, userLocationInput +from config import riva_config + +verbose = riva_config["VERBOSE"] + +############################################################################### +# stateDM (Simple Dialog Manager): A Finite State Machine +############################################################################### +class StateMachine: + def __init__(self, user_conversation_index, init): + self.uid = user_conversation_index + self.bot = "rivaWeather" + if verbose: + print("[stateDM] Initializing the state machine for uid: ", self.uid) + self.currentState = init(self.bot, self.uid) + + def execute_state(self, bot, context, text): + # Fresh request frame + request_data = {'context': context, + 'text': text, + 'uid': self.uid, + 'payload': {}} + + # TODO: Add support for !undo (saving previous context) and !reset + + # Keep executing the state machine until a user input is required + # i.e. 
stop when state is either InputUser or InputContext + while True: + # Run the current state + if verbose: + print("[stateDM] Executing state:", + self.currentState.name) + self.currentState.run(request_data) + nextState = self.currentState.next() + + # If the next state exists + if nextState is not None: + # Create an object from the next state + self.currentState = nextState(self.bot, self.uid) + # If the next state requires user input, just return + # WARNING: Can go into infinite loop if states don't have + # next_state configured properly + if nextState == userInput or nextState == userLocationInput: + return request_data + + # If no next state exists, wait for user input now + else: + if verbose: + print("[stateDM] No next state, waiting for user input") + self.currentState = userInput(self.bot, self.uid) + return request_data diff --git a/virtual-assistant-nemo-llm/riva_local/chatbot/stateDM/states.py b/virtual-assistant-nemo-llm/riva_local/chatbot/stateDM/states.py new file mode 100644 index 0000000..c9c9709 --- /dev/null +++ b/virtual-assistant-nemo-llm/riva_local/chatbot/stateDM/states.py @@ -0,0 +1,139 @@ +# ============================================================================== +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# The License information can be found under the "License" section of the +# README.md file. +# ============================================================================== + +from riva_local.chatbot.stateDM.state import State +from riva_local.nlp.nlp import get_entities +from riva_local.chatbot.stateDM.state_data import intent_transitions +from riva_local.chatbot.stateDM.Weather import Weather +import sys +from config import riva_config + +verbose = riva_config["VERBOSE"] + + +class initialState(State): + def __init__(self, bot, uid): + super(initialState, self).__init__("initialState", bot, uid) + + def run(self, request_data): + text = "Hi, welcome to Misty's weather service. How may I help you?" + + # Update response with welcome text + request_data.update({'response': + self.construct_message(request_data, text)}) + + self.next_state = userInput + + +class userInput(State): + def __init__(self, bot, uid): + super(userInput, self).__init__("userInput", bot, uid) + self.next_state = None + + def get_state(self, class_str, default): + return getattr(sys.modules[__name__], class_str, default) + + def run(self, request_data): + # Get response from Riva NLU + response = get_entities(request_data['text'], "riva") + response_intent = response.get('intent', False) + + # Fetch the transitions dict for the bot + intents_project = intent_transitions[self.bot] + + # If a valid intent was detected + if response_intent: + # If a valid state exists for the response intent AND + # the response intent is different from the one already in context + if intents_project.get(response_intent, False) and \ + response_intent != request_data['context'].get('intent', False): + self.next_state = self.get_state(intents_project.get(response_intent, False), None) + + # update request_data with response and next_state + request_data['context'].update(response) + return + + # If intent exists in the context, use that + if 'intent' in request_data['context']: + # Populate context with response (eg. 
+            request_data['context'].update({x: response[x] for x in response if x != 'intent'})
+            self.next_state = self.get_state(intents_project.get(request_data['context']['intent'], False), None)
+            return
+
+
+class userLocationInput(State):
+    def __init__(self, bot, uid):
+        super(userLocationInput, self).__init__("userLocationInput", bot, uid)
+
+    def run(self, request_data):
+        response = get_entities(request_data['text'], "riva")
+
+        # Update all keys except the intent
+        request_data['context'].update(
+            {x: response[x] for x in response if x != 'intent'})
+
+        # Check whether the required entities (here, the location) are present
+        # If present, proceed to the Weather state
+        if 'location' in response:
+            # Move to the Weather state
+            self.next_state = Weather
+        else:
+            # Otherwise, proceed to the error state
+            self.next_state = error
+
+
+class checkWeatherLocation(State):
+    def __init__(self, bot, uid):
+        super(checkWeatherLocation, self).__init__(
+            "checkWeatherLocation", bot, uid)
+
+    def run(self, request_data):
+        # Check whether all the entities (the location) required for reporting the weather exist
+        location = request_data['context'].get("location", False)
+
+        if location:
+            # If a location exists, call the Weather class to check the weather for it
+            self.next_state = Weather
+        else:
+            # Otherwise, ask for a location and move to userLocationInput to collect it
+            text = "For which location?"
+
+            # Update the response asking for the user's location; the intent stays the same
+            request_data.update({'response':
+                                 self.construct_message(request_data, text)})
+
+            self.next_state = userLocationInput
+
+
+class error(State):
+    def __init__(self, bot, uid):
+        super(error, self).__init__("error", bot, uid)
+
+    def run(self, request_data):
+        text = "Sorry, I didn't quite get that."
+
+        # Update response with error text
+        request_data.update({'response':
+                             self.construct_message(request_data, text)})
+
+        self.next_state = userInput
+
+
+# TODO: This state is currently unused;
+# wire it in if an explicit end of conversation is required
+class end(State):
+    def __init__(self, bot, uid):
+        super(end, self).__init__("end", bot, uid)
+
+    def run(self, request_data):
+        text = "Bye!"
+
+        # Update response with end state text
+        request_data.update({'response':
+                             self.construct_message(request_data, text)})
+
+        self.next_state = userInput
diff --git a/virtual-assistant-nemo-llm/riva_local/nlp/__init__.py b/virtual-assistant-nemo-llm/riva_local/nlp/__init__.py
new file mode 100644
index 0000000..d90b5a1
--- /dev/null
+++ b/virtual-assistant-nemo-llm/riva_local/nlp/__init__.py
@@ -0,0 +1,8 @@
+# ==============================================================================
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# The License information can be found under the "License" section of the
+# README.md file.
+# ==============================================================================
+
+from .nlp import *
\ No newline at end of file
diff --git a/virtual-assistant-nemo-llm/riva_local/nlp/nlp.py b/virtual-assistant-nemo-llm/riva_local/nlp/nlp.py
new file mode 100644
index 0000000..8d59277
--- /dev/null
+++ b/virtual-assistant-nemo-llm/riva_local/nlp/nlp.py
@@ -0,0 +1,218 @@
+# ==============================================================================
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# The License information can be found under the "License" section of the
+# README.md file.
+# ==============================================================================
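+
+# NLU lookup order implemented below: the intent/slot model first, the NER
+# model ("riva_ner") as a location fallback, and the QA API as a last resort.
+# Illustrative call (exact values depend on the deployed models):
+#     get_entities("How hot is it in Berlin?", "riva")
+#     -> {'intent': 'weather.temperature', 'location': 'Berlin', 'payload': {}}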
+
+import riva.client
+from riva.client.proto.riva_nlp_pb2 import (
+    AnalyzeIntentResponse,
+    NaturalQueryResponse,
+    TokenClassResponse
+)
+
+import grpc
+from config import riva_config, nlp_config
+import requests
+import json
+
+# QA api-endpoint
+QA_API_ENDPOINT = nlp_config["QA_API_ENDPOINT"]
+enable_qa = riva_config["ENABLE_QA"]
+verbose = riva_config["VERBOSE"]
+
+auth = riva.client.Auth(uri=riva_config["RIVA_SPEECH_API_URL"])
+riva_nlp = riva.client.NLPService(auth)
+
+
+def get_qa_answer(context, question, p_threshold):
+    # Data to be sent to the QA API
+    data = {
+        "question": question,
+        "context": context
+    }
+    # Send the POST request and save the response object
+    r = requests.post(QA_API_ENDPOINT, json=data)
+
+    # Extract the response text
+    qa_resp = json.loads(r.text)
+
+    if verbose:
+        print("[Riva NLU] The answer is :%s" % qa_resp['result'])
+        print("[Riva NLU] The probability is :%s" % qa_resp['p'])
+
+    if qa_resp['result'] == '':
+        print("[Riva NLU] QA returned empty string.")
+
+    if qa_resp['p'] < p_threshold:
+        print("[Riva NLU] QA response lower than threshold - ", p_threshold)
+
+    return qa_resp
+
+
+if enable_qa == "true":
+    # Test question and passage to be sent to the QA API
+    riva_test = "I am Riva. I can talk about the weather. My favorite season is spring. I know the weather info " \
+                "from Weatherstack api. I have studied the weather all my life."
+    test_question = "What is your name?"
+    p_threshold = 0.4
+    get_qa_answer(riva_test, test_question, p_threshold)
+
+
+def get_intent(resp, entities):
+    if hasattr(resp, 'intent'):
+        entities['intent'] = resp.intent.class_name
+
+
+def get_slots(resp, entities):
+    entities['payload'] = dict()
+    all_entities_class = {}
+    all_entities = []
+    if hasattr(resp, 'slots'):
+        for i in range(len(resp.slots)):
+            slot_class = resp.slots[i].label[0].class_name.replace("\r", "")
+            token = resp.slots[i].token.replace("?", "").replace(",", "").replace(".", "").replace("[SEP]", "").strip()
+            score = resp.slots[i].label[0].score
+            if slot_class and token:
+                if slot_class == 'weatherplace' or slot_class == 'destinationplace':
+                    entity = { "value": token,
+                               "confidence": score,
+                               "entity": "location" }
+                else:
+                    entity = { "value": token,
+                               "confidence": score,
+                               "entity": slot_class }
+                all_entities_class[entity["entity"]] = 1
+                all_entities.append(entity)
+    # For each entity class, keep only the highest-confidence value
+    for cl in all_entities_class:
+        partial_entities = list(filter(lambda x: x["entity"] == cl, all_entities))
+        partial_entities.sort(reverse=True, key=lambda x: x["confidence"])
+        for entity in partial_entities:
+            if cl == "location":
+                entities['location'] = entity["value"]
+            else:
+                entities['payload'][cl] = entity["value"]
+            break
+
+
+def get_riva_output(text):
+    # Submit an AnalyzeIntent request. A domain is provided with the query (see
+    # AnalyzeIntentOptions below), so the query is run directly through the
+    # matching intent/slot classifier.
+    # Note: the detected domain is also returned in the response.
+    try:
+        # The <domain> is appended to "riva_intent_" to look up the model "riva_intent_<domain>",
+        # so the model "riva_intent_<domain>" needs to be preloaded in the Riva server.
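+        # Illustrative response fields consumed by get_intent/get_slots
+        # (values are examples, not fixed):
+        #     resp.intent.class_name            -> e.g. 'weather.weather'
+        #     resp.slots[i].token               -> e.g. 'Berlin'
+        #     resp.slots[i].label[0].class_name -> e.g. 'weatherplace'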
+        # In this case the domain is weather, so the model being used is "riva_intent_weather".
+        options = riva.client.AnalyzeIntentOptions(lang='en-US', domain='weather')
+
+        resp: AnalyzeIntentResponse = riva_nlp.analyze_intent(text, options)
+
+    except Exception as inst:
+        # An exception occurred
+        print("[Riva NLU] Error during NLU request: ", inst)
+        return {'riva_error': 'riva_error'}
+    entities = {}
+    get_intent(resp, entities)
+    get_slots(resp, entities)
+    if 'location' not in entities:
+        if verbose:
+            print(f"[Riva NLU] Did not find any location in the string: {text}\n"
+                  "[Riva NLU] Checking again using NER model")
+        try:
+            model_name = "riva_ner"
+            resp_ner: TokenClassResponse = riva_nlp.classify_tokens(text, model_name)
+        except Exception as inst:
+            # An exception occurred
+            print("[Riva NLU] Error during NLU request (riva_ner): ", inst)
+            return {'riva_error': 'riva_error'}
+
+        if verbose:
+            print(f"[Riva NLU] NER response results: \n {resp_ner.results[0].results}\n")
+            print("[Riva NLU] Location Entities:")
+        loc_count = 0
+        for result in resp_ner.results[0].results:
+            if result.label[0].class_name == "LOC":
+                if verbose:
+                    print(f"[Riva NLU] Location found: {result.token}") # Flow unhandled for multiple location inputs
+                loc_count += 1
+                entities['location'] = result.token
+        if loc_count == 0:
+            if verbose:
+                print("[Riva NLU] No location found in string using NER LOC")
+                print("[Riva NLU] Checking response domain")
+            if resp.domain.class_name == "nomatch.none":
+                # As a final resort, try the QA API
+                if enable_qa == "true":
+                    if verbose:
+                        print("[Riva NLU] Checking using QA API")
+                    riva_misty_profile = requests.get(nlp_config["RIVA_MISTY_PROFILE"]).text # Live pull from Cloud
+                    qa_resp = get_qa_answer(riva_misty_profile, text, p_threshold)
+                    if qa_resp['result'] != '':
+                        if verbose:
+                            print("[Riva NLU] received qa result")
+                        entities['intent'] = 'qa_answer'
+                        entities['answer_span'] = qa_resp['result']
+                        entities['query'] = text
+                    else:
+                        entities['intent'] = 'riva_error'
+                else:
+                    entities['intent'] = 'riva_error'
+    if verbose:
+        print("[Riva NLU] This is what entities contain: ", entities)
+    return entities
+
+
+def get_riva_output_qa_only(text):
+    # Query the QA API directly against the RIVA_MARK_KB knowledge base;
+    # the intent/slot and NER models are not used here.
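+    # get_qa_answer returns a dict with an answer span and a confidence score,
+    # e.g. {'result': 'spring', 'p': 0.87} (illustrative values)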
+    entities = {}
+    try:
+        if enable_qa == "true":
+            if verbose:
+                print("[Riva NLU] Checking using QA API")
+            riva_mark_KB = requests.get(nlp_config["RIVA_MARK_KB"]).text # Live pull from Cloud
+            qa_resp = get_qa_answer(riva_mark_KB, text, p_threshold)
+            if qa_resp['result'] != '':
+                if verbose:
+                    print("[Riva NLU] received qa result")
+                entities['intent'] = 'qa_answer'
+                entities['answer_span'] = qa_resp['result']
+                entities['query'] = text
+            else:
+                entities['intent'] = 'riva_error'
+        else:
+            entities['intent'] = 'riva_error'
+    except Exception as inst:
+        # An exception occurred
+        print("[Riva NLU] Error during NLU request: ", inst)
+        return {'riva_error': 'riva_error'}
+    if verbose:
+        print("[Riva NLU] This is what entities contain: ", entities)
+    return entities
+
+
+def get_entities(text, nlp_type):
+    if nlp_type is None:
+        nlp_type = "empty"
+
+    ent_out = {}
+    if nlp_type == "empty":
+        ent_out.update({'raw_text': str(text)})
+    elif nlp_type == "riva":
+        riva_out = get_riva_output(text)
+        ent_out.update(riva_out)
+    elif nlp_type == "riva_mark":
+        riva_out = get_riva_output_qa_only(text)
+        ent_out.update(riva_out)
+    return ent_out
diff --git a/virtual-assistant-nemo-llm/riva_local/nlp/test_nlp.py b/virtual-assistant-nemo-llm/riva_local/nlp/test_nlp.py
new file mode 100644
index 0000000..a977585
--- /dev/null
+++ b/virtual-assistant-nemo-llm/riva_local/nlp/test_nlp.py
@@ -0,0 +1,119 @@
+import riva.client
+from riva.client.proto.riva_nlp_pb2 import AnalyzeIntentResponse, TokenClassResponse # Used in annotations below
+# from config import riva_config, nlp_config
+from pprint import pprint
+
+# QA api-endpoint
+# QA_API_ENDPOINT = nlp_config["QA_API_ENDPOINT"]
+# enable_qa = riva_config["ENABLE_QA"]
+# verbose = riva_config["VERBOSE"]
+
+# auth = riva.client.Auth(uri=riva_config["RIVA_SPEECH_API_URL"])
+auth = riva.client.Auth(uri='localhost:50051')
+riva_nlp = riva.client.NLPService(auth)
+
+def get_intent(resp, entities):
+    if hasattr(resp, 'intent'):
+        entities['intent'] = resp.intent.class_name
+
+
+def get_slots(resp, entities):
+    entities['payload'] = dict()
+    all_entities_class = {}
+    all_entities = []
+    if hasattr(resp, 'slots'):
+        for i in range(len(resp.slots)):
+            slot_class = resp.slots[i].label[0].class_name.replace("\r", "")
+            token = resp.slots[i].token.replace("?", "").replace(",", "").replace(".", "").replace("[SEP]", "").strip()
+            score = resp.slots[i].label[0].score
+            if slot_class and token:
+                if slot_class == 'weatherplace' or slot_class == 'destinationplace':
+                    entity = { "value": token,
+                               "confidence": score,
+                               "entity": "location" }
+                else:
+                    entity = { "value": token,
+                               "confidence": score,
+                               "entity": slot_class }
+                all_entities_class[entity["entity"]] = 1
+                all_entities.append(entity)
+    # For each entity class, keep only the highest-confidence value
+    for cl in all_entities_class:
+        partial_entities = list(filter(lambda x: x["entity"] == cl, all_entities))
+        partial_entities.sort(reverse=True, key=lambda x: x["confidence"])
+        for entity in partial_entities:
+            if cl == "location":
+                entities['location'] = entity["value"]
+            else:
+                entities['payload'][cl] = entity["value"]
+            break
+
+
+def get_riva_output(text, verbose=True, enable_qa=False):
+    # Submit an AnalyzeIntent request. A domain is provided with the query (see
+    # AnalyzeIntentOptions below), so the query is run directly through the
+    # matching intent/slot classifier.
+    # Note: the detected domain is also returned in the response.
+    try:
+        # The <domain> is appended to "riva_intent_" to look up the model "riva_intent_<domain>",
+        # so the model "riva_intent_<domain>" needs to be preloaded in the Riva server.
+        # In this case the domain is weather, so the model being used is "riva_intent_weather".
+        options = riva.client.AnalyzeIntentOptions(lang='en-US', domain='weather')
+
+        resp: AnalyzeIntentResponse = riva_nlp.analyze_intent(text, options)
+
+    except Exception as inst:
+        # An exception occurred
+        print("[Riva NLU] Error during NLU request: ", inst)
+        return {'riva_error': 'riva_error'}
+    entities = {}
+    get_intent(resp, entities)
+    get_slots(resp, entities)
+    if 'location' not in entities:
+        if verbose:
+            print(f"[Riva NLU] Did not find any location in the string: {text}\n"
+                  "[Riva NLU] Checking again using NER model")
+        try:
+            model_name = "riva_ner"
+            resp_ner: TokenClassResponse = riva_nlp.classify_tokens(text, model_name)
+        except Exception as inst:
+            # An exception occurred
+            print("[Riva NLU] Error during NLU request (riva_ner): ", inst)
+            return {'riva_error': 'riva_error'}
+
+        if verbose:
+            print(f"[Riva NLU] NER response results: \n {resp_ner.results[0].results}\n")
+            print("[Riva NLU] Location Entities:")
+        loc_count = 0
+        for result in resp_ner.results[0].results:
+            if result.label[0].class_name == "LOC":
+                if verbose:
+                    print(f"[Riva NLU] Location found: {result.token}") # Flow unhandled for multiple location inputs
+                loc_count += 1
+                entities['location'] = result.token
+        if loc_count == 0:
+            if verbose:
+                print("[Riva NLU] No location found in string using NER LOC")
+                print("[Riva NLU] Checking response domain")
+            if resp.domain.class_name == "nomatch.none":
+                # As a final resort, try the QA API
+                # NOTE: this branch is inert here (enable_qa defaults to False)
+                # and would also need requests, nlp_config, get_qa_answer and
+                # p_threshold from nlp.py if enabled
+                if enable_qa == "true":
+                    if verbose:
+                        print("[Riva NLU] Checking using QA API")
+                    riva_misty_profile = requests.get(nlp_config["RIVA_MISTY_PROFILE"]).text # Live pull from Cloud
+                    qa_resp = get_qa_answer(riva_misty_profile, text, p_threshold)
+                    if qa_resp['result'] != '':
+                        if verbose:
+                            print("[Riva NLU] received qa result")
+                        entities['intent'] = 'qa_answer'
+                        entities['answer_span'] = qa_resp['result']
+                        entities['query'] = text
+                    else:
+                        entities['intent'] = 'riva_error'
+                else:
+                    entities['intent'] = 'riva_error'
+    if verbose:
+        print("[Riva NLU] This is what entities contain: ", entities)
+    return entities
+
+text = "What is the chance of rain at 4:00 pm on Saturday in Berkeley?"
+entities = get_riva_output(text)
+pprint(entities)
diff --git a/virtual-assistant-nemo-llm/riva_local/tts/__init__.py b/virtual-assistant-nemo-llm/riva_local/tts/__init__.py
new file mode 100644
index 0000000..a54072c
--- /dev/null
+++ b/virtual-assistant-nemo-llm/riva_local/tts/__init__.py
@@ -0,0 +1,2 @@
+from .tts import *
+from .tts_stream import *
\ No newline at end of file
diff --git a/virtual-assistant-nemo-llm/riva_local/tts/tts.py b/virtual-assistant-nemo-llm/riva_local/tts/tts.py
new file mode 100644
index 0000000..f4536e3
--- /dev/null
+++ b/virtual-assistant-nemo-llm/riva_local/tts/tts.py
@@ -0,0 +1,114 @@
+# ==============================================================================
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# The License information can be found under the "License" section of the
+# README.md file.
+# ==============================================================================
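+
+# Batch TTS pipeline: get_speech() yields a WAV header (written with a zero
+# data size) followed by raw PCM chunks, so the client can treat the response
+# as an open-ended audio stream.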
+
+import grpc
+import riva.client
+from six.moves import queue
+from config import riva_config, tts_config
+import numpy as np
+import time
+
+# Default TTS parameters - used in case the values are not specified in the config.py file
+VERBOSE = False
+SAMPLE_RATE = 22050
+LANGUAGE_CODE = "en-US"
+VOICE_NAME = "English-US.Female-1"
+
+class TTSPipe(object):
+    """Opens a gRPC channel to Riva TTS to synthesize speech
+    from text in batch mode."""
+
+    def __init__(self):
+        self.verbose = tts_config["VERBOSE"] if "VERBOSE" in tts_config else VERBOSE
+        self.sample_rate = tts_config["SAMPLE_RATE"] if "SAMPLE_RATE" in tts_config else SAMPLE_RATE
+        self.language_code = tts_config["LANGUAGE_CODE"] if "LANGUAGE_CODE" in tts_config else LANGUAGE_CODE
+        self.voice_name = tts_config["VOICE_NAME"] if "VOICE_NAME" in tts_config else VOICE_NAME
+        self.audio_encoding = riva.client.AudioEncoding.LINEAR_PCM
+        self._buff = queue.Queue()
+        self.closed = False
+        self._flusher = bytes(np.zeros(dtype=np.int16, shape=(self.sample_rate, 1))) # One second of silence
+        self.current_tts_duration = 0
+
+    def start(self):
+        if self.verbose:
+            print('[Riva TTS] Creating TTS channel: {}'.format(riva_config["RIVA_SPEECH_API_URL"]))
+        self.auth = riva.client.Auth(uri=riva_config["RIVA_SPEECH_API_URL"])
+        self.riva_tts = riva.client.SpeechSynthesisService(self.auth)
+
+    def reset_current_tts_duration(self):
+        self.current_tts_duration = 0
+
+    def get_current_tts_duration(self):
+        return self.current_tts_duration
+
+    def fill_buffer(self, in_data):
+        """Collects text responses from the state machine output into a buffer."""
+        if len(in_data):
+            self._buff.put(in_data)
+
+    def close(self):
+        self.closed = True
+        self._buff.queue.clear()
+        self._buff.put(None) # Signals the end of the stream
+
+    def get_speech(self):
+        """Returns speech audio from the text responses in the buffer."""
+        self.start()
+        wav_header = self.gen_wav_header(self.sample_rate, 16, 1, 0)
+        yield bytes(wav_header)
+        flush_count = 0
+        while not self.closed:
+            if not self._buff.empty(): # Enter if the queue/buffer is not empty
+                try:
+                    text = self._buff.get(block=False, timeout=0)
+                    if self.verbose:
+                        print('[Riva TTS] Pronounced Text: ', text)
+                    resp = self.riva_tts.synthesize(
+                        text = text,
+                        language_code = self.language_code,
+                        encoding = riva.client.AudioEncoding.LINEAR_PCM,
+                        sample_rate_hz = self.sample_rate,
+                        voice_name = self.voice_name
+                    )
+                    datalen = len(resp.audio) // 2
+                    data16 = np.ndarray(buffer=resp.audio, dtype=np.int16, shape=(datalen, 1))
+                    speech = bytes(data16.data)
+                    duration = len(data16) * 2 / (self.sample_rate * 1 * 16 / 8)
+                    if self.verbose:
+                        print(f'[Riva TTS] The datalen is: {datalen}')
+                        print(f'[Riva TTS] Duration of audio is: {duration}')
+                    self.current_tts_duration = duration
+                    yield speech
+                    flush_count = 5
+                    continue
+                except Exception as e:
+                    print('[Riva TTS] ERROR:')
+                    print(str(e))
+
+            # To flush out remaining audio from the client buffer
+            if flush_count > 0:
+                yield self._flusher
+                flush_count -= 1
+                continue
+            time.sleep(0.1) # Buffer check rate
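+    # WAV header arithmetic used below (for reference): byte rate =
+    # sample_rate * channels * bits_per_sample / 8, e.g. 22050 * 1 * 16 / 8 =
+    # 44100 bytes per second for the defaults above.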
+
+    def gen_wav_header(self, sample_rate, bits_per_sample, channels, datasize):
+        o = bytes("RIFF", 'ascii') # (4byte) Marks file as RIFF
+        o += (datasize + 36).to_bytes(4, 'little') # (4byte) File size in bytes, excluding this and the RIFF marker
+        o += bytes("WAVE", 'ascii') # (4byte) File type
+        o += bytes("fmt ", 'ascii') # (4byte) Format chunk marker
+        o += (16).to_bytes(4, 'little') # (4byte) Length of the format data above
+        o += (1).to_bytes(2, 'little') # (2byte) Format type (1 - PCM)
+        o += channels.to_bytes(2, 'little') # (2byte) Number of channels
+        o += sample_rate.to_bytes(4, 'little') # (4byte) Sample rate
+        o += (sample_rate * channels * bits_per_sample // 8).to_bytes(4, 'little') # (4byte) Byte rate
+        o += (channels * bits_per_sample // 8).to_bytes(2, 'little') # (2byte) Block align
+        o += bits_per_sample.to_bytes(2, 'little') # (2byte) Bits per sample
+        o += bytes("data", 'ascii') # (4byte) Data chunk marker
+        o += datasize.to_bytes(4, 'little') # (4byte) Data size in bytes
+        return o
\ No newline at end of file
diff --git a/virtual-assistant-nemo-llm/riva_local/tts/tts_stream.py b/virtual-assistant-nemo-llm/riva_local/tts/tts_stream.py
new file mode 100644
index 0000000..e169c1c
--- /dev/null
+++ b/virtual-assistant-nemo-llm/riva_local/tts/tts_stream.py
@@ -0,0 +1,123 @@
+# ==============================================================================
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# The License information can be found under the "License" section of the
+# README.md file.
+# ==============================================================================
+
+import grpc
+import riva.client
+
+from six.moves import queue
+from config import riva_config, tts_config
+import numpy as np
+import time
+
+# Default TTS parameters - used in case the values are not specified in the config.py file
+VERBOSE = False
+SAMPLE_RATE = 22050
+LANGUAGE_CODE = "en-US"
+VOICE_NAME = "English-US.Female-1"
+
+class TTSPipe(object):
+    """Opens a gRPC channel to Riva TTS to synthesize speech
+    from text in streaming mode."""
+
+    def __init__(self):
+        self.verbose = tts_config["VERBOSE"] if "VERBOSE" in tts_config else VERBOSE
+        self.sample_rate = tts_config["SAMPLE_RATE"] if "SAMPLE_RATE" in tts_config else SAMPLE_RATE
+        self.language_code = tts_config["LANGUAGE_CODE"] if "LANGUAGE_CODE" in tts_config else LANGUAGE_CODE
+        self.voice_name = tts_config["VOICE_NAME"] if "VOICE_NAME" in tts_config else VOICE_NAME
+        self.audio_encoding = riva.client.AudioEncoding.LINEAR_PCM
+        self._buff = queue.Queue()
+        self.closed = False
+        self._flusher = bytes(np.zeros(dtype=np.int16, shape=(self.sample_rate, 1))) # One second of silence
+        self.current_tts_duration = 0
+
+    def start(self):
+        if self.verbose:
+            print('[Riva TTS] Creating Stream TTS channel: {}'.format(riva_config["RIVA_SPEECH_API_URL"]))
+        self.auth = riva.client.Auth(uri=riva_config["RIVA_SPEECH_API_URL"])
+        self.riva_tts = riva.client.SpeechSynthesisService(self.auth)
+
+    def reset_current_tts_duration(self):
+        self.current_tts_duration = 0
+
+    def get_current_tts_duration(self):
+        return self.current_tts_duration
+
+    def fill_buffer(self, in_data):
+        """Collects text responses from the state machine output into a buffer."""
+        if len(in_data):
+            self._buff.put(in_data)
+
+    def close(self):
+        self.closed = True
+        self._buff.queue.clear()
+        self._buff.put(None) # Signals the end of the stream
+
+    def get_speech(self):
+        """Returns speech audio from the text responses in the buffer."""
+        self.start()
+        wav_header = self.gen_wav_header(self.sample_rate, 16, 1, 0)
+        yield bytes(wav_header)
+
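+        # Poll the text buffer, synthesize each response in streaming mode, and
+        # yield PCM chunks as they arrive; silence is flushed while idle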
+        flush_count = 0
+        while not self.closed:
+            if not self._buff.empty(): # Enter if the queue/buffer is not empty
+                try:
+                    text = self._buff.get(block=False, timeout=0)
+                    if self.verbose:
+                        print('[Riva TTS] Pronounced Text: ', text)
+                        print('[Riva TTS] Starting TTS streaming')
+                    duration = 0
+                    self.current_tts_duration = 0
+
+                    responses = self.riva_tts.synthesize_online(
+                        text = text,
+                        language_code = self.language_code,
+                        encoding = riva.client.AudioEncoding.LINEAR_PCM,
+                        sample_rate_hz = self.sample_rate,
+                        voice_name = self.voice_name
+                    )
+
+                    for resp in responses:
+                        datalen = len(resp.audio) // 2
+                        data16 = np.ndarray(buffer=resp.audio, dtype=np.int16, shape=(datalen, 1))
+                        speech = bytes(data16.data)
+                        chunk_duration = len(data16) * 2 / (self.sample_rate * 1 * 16 / 8)
+                        duration += chunk_duration
+                        self.current_tts_duration += chunk_duration
+                        if self.verbose:
+                            print(f'[Riva TTS] Duration of audio is: {duration}')
+                        yield speech
+                except Exception as e:
+                    print('[Riva TTS] ERROR:')
+                    print(str(e))
+                flush_count = 5
+                continue
+            # To flush out remaining audio from the client buffer
+            if flush_count > 0:
+                yield self._flusher
+                flush_count -= 1
+                continue
+            time.sleep(0.1) # Buffer check rate
+
+    def gen_wav_header(self, sample_rate, bits_per_sample, channels, datasize):
+        o = bytes("RIFF", 'ascii') # (4byte) Marks file as RIFF
+        o += (datasize + 36).to_bytes(4, 'little') # (4byte) File size in bytes, excluding this and the RIFF marker
+        o += bytes("WAVE", 'ascii') # (4byte) File type
+        o += bytes("fmt ", 'ascii') # (4byte) Format chunk marker
+        o += (16).to_bytes(4, 'little') # (4byte) Length of the format data above
+        o += (1).to_bytes(2, 'little') # (2byte) Format type (1 - PCM)
+        o += channels.to_bytes(2, 'little') # (2byte) Number of channels
+        o += sample_rate.to_bytes(4, 'little') # (4byte) Sample rate
+        o += (sample_rate * channels * bits_per_sample // 8).to_bytes(4, 'little') # (4byte) Byte rate
+        o += (channels * bits_per_sample // 8).to_bytes(2, 'little') # (2byte) Block align
+        o += bits_per_sample.to_bytes(2, 'little') # (2byte) Bits per sample
+        o += bytes("data", 'ascii') # (4byte) Data chunk marker
+        o += datasize.to_bytes(4, 'little') # (4byte) Data size in bytes
+        return o
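+
+# Usage sketch (hypothetical; not part of the web app's actual wiring, and it
+# assumes a Riva server reachable at the URL configured in config.py):
+#
+#     tts = TTSPipe()
+#     tts.fill_buffer("It is sunny in Berlin today.")
+#     stream = tts.get_speech()   # generator: WAV header first, then PCM chunks
+#     wav_header = next(stream)
+#     first_chunk = next(stream)  # 16-bit mono PCM at the configured sample rate
+#     tts.close()                 # stops the generator loop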