diff --git a/cogs/messagehandler.py b/cogs/messagehandler.py
index 2323dbd..ffc2f03 100644
--- a/cogs/messagehandler.py
+++ b/cogs/messagehandler.py
@@ -249,17 +249,18 @@ async def handle_text_message(self, message, mode=""):
                 message,
             )
             await self.add_message_to_dict(message, message.clean_content)
-            async with message.channel.typing():
-                # If the response is more than 2000 characters, split it
-                chunks = [response[i : i + 1998] for i in range(0, len(response), 1998)]
-                for chunk in chunks:
-                    print(chunk)
-                    response_obj = await message.channel.send(chunk)
-                    await self.add_message_to_dict(
-                        response_obj, response_obj.clean_content
-                    )
-                    # self.bot.sent_last_message[str(message.channel.id)] = True
-                    # await log_message(response_obj)
+            if response:
+                async with message.channel.typing():
+                    # If the response is more than 2000 characters, split it
+                    chunks = [response[i : i + 1998] for i in range(0, len(response), 1998)]
+                    for chunk in chunks:
+                        print(chunk)
+                        response_obj = await message.channel.send(chunk)
+                        await self.add_message_to_dict(
+                            response_obj, response_obj.clean_content
+                        )
+                        # self.bot.sent_last_message[str(message.channel.id)] = True
+                        # await log_message(response_obj)

     async def set_listen_only_mode_timer(self, channel_id):
         # Start the timer
diff --git a/cogs/pygbot.py b/cogs/pygbot.py
index 03a69eb..df96151 100644
--- a/cogs/pygbot.py
+++ b/cogs/pygbot.py
@@ -9,6 +9,7 @@
 from discord import app_commands
 from discord.ext import commands
 import os
+import asyncio

 # load environment STOP_SEQUENCES variables and split them into a list by comma
@@ -127,7 +128,7 @@ async def get_memory_for_channel(self, channel_id):
             name = message[0]
             channel_ids = str(message[1])
             message = message[2]
-            print(f"{name}: {message}")
+            #print(f"{name}: {message}")
             await self.add_history(name, channel_ids, message)

         # self.memory = self.histories[channel_id]
@@ -160,7 +161,7 @@ async def generate_response(self, message, message_content) -> None:
         name = message.author.display_name
         memory = await self.get_memory_for_channel(str(channel_id))
         stop_sequence = await self.get_stop_sequence_for_channel(channel_id, name)
-        print(f"stop sequences: {stop_sequence}")
+        #print(f"stop sequences: {stop_sequence}")
         formatted_message = f"{name}: {message_content}"
         MAIN_TEMPLATE = f"""
         {self.top_character_info}
@@ -180,7 +181,13 @@ async def generate_response(self, message, message_content) -> None:
             memory=memory,
         )
         input_dict = {"input": formatted_message, "stop": stop_sequence}
-        response_text = conversation(input_dict)
+
+        # Run the conversation chain
+        if self.bot.koboldcpp_version >= 1.29:
+            response_text = await conversation.acall(input_dict, channel_id)
+        else:
+            response_text = await conversation.acall(input_dict)
+
         response = await self.detect_and_replace_out(response_text["response"])
         with open(self.convo_filename, "a", encoding="utf-8") as f:
             f.write(f"{message.author.display_name}: {message_content}\n")
@@ -199,7 +206,7 @@ async def add_history(self, name, channel_id, message_content) -> None:
         formatted_message = f"{name}: {message_content}"

         # add the message to the memory
-        print(f"adding message to memory: {formatted_message}")
+        #print(f"adding message to memory: {formatted_message}")
         memory.add_input_only(formatted_message)
         return None
@@ -210,6 +217,10 @@ def __init__(self, bot):
         self.chatlog_dir = bot.chatlog_dir
         self.chatbot = Chatbot(bot)

+        # Store the current task and the last message per channel
+        self.current_tasks = {}
+        self.last_messages = {}
+
         # create chatlog directory if it doesn't exist
         if not os.path.exists(self.chatlog_dir):
             os.makedirs(self.chatlog_dir)
@@ -233,8 +244,32 @@ async def chat_command(self, name, channel_id, message_content, message) -> None
             and self.chatbot.convo_filename != chatlog_filename
         ):
             await self.chatbot.set_convo_filename(chatlog_filename)
-        response = await self.chatbot.generate_response(message, message_content)
-        return response
+
+        # Check whether a task is still running for this channel ID
+        #print(f"The current task is: {self.current_tasks[channel_id]}")  # for debugging purposes
+        if channel_id in self.current_tasks:
+            task = self.current_tasks[channel_id]
+
+            if task is not None and not task.done():
+                # Cancelling the previous task; add its last message to the history
+                await self.chatbot.add_history(name, str(channel_id), self.last_messages[channel_id])
+
+                # If the endpoint is koboldcpp, stop the generation for this channel ID
+                if self.bot.koboldcpp_version >= 1.29:
+                    await self.bot.llm._stop(channel_id)
+
+                task.cancel()
+
+        # Create a new task and last message bound to the channel ID
+        self.last_messages[channel_id] = message_content
+        self.current_tasks[channel_id] = asyncio.create_task(self.chatbot.generate_response(message, message_content))
+
+        try:
+            response = await self.current_tasks[channel_id]
+            return response
+        except asyncio.CancelledError:
+            print(f"Cancelled {self.chatbot.char_name}'s current response, regenerating a new reply...")
+            return None

     # No Response Handler
     @commands.command(name="chatnr")
diff --git a/discordbot.py b/discordbot.py
index d0d3cd3..990eea7 100644
--- a/discordbot.py
+++ b/discordbot.py
@@ -6,7 +6,8 @@
 from pathlib import Path
 import base64
 from helpers.textgen import TextGen
-from langchain.llms import KoboldApiLLM, OpenAI
+from helpers.koboldai import KoboldApiLLM
+from langchain.llms import OpenAI
 from discord import app_commands
 from discord.ext import commands
 from discord.ext.commands import Bot
@@ -237,6 +238,13 @@ async def on_ready():
             "\n\n\n\nERROR: Unable to retrieve channel from .env \nPlease make sure you're using a valid channel ID, not a server ID."
         )

+    # Check if the endpoint is connected to koboldcpp
+    if bot.llm._llm_type == "koboldai":
+        bot.koboldcpp_version = bot.llm.check_version()
+        print(f"KoboldCPP Version: {bot.koboldcpp_version}")
+    else:
+        bot.koboldcpp_version = 0.0
+

 # COG LOADER
 async def load_cogs() -> None:
diff --git a/helpers/koboldai.py b/helpers/koboldai.py
new file mode 100644
index 0000000..9248279
--- /dev/null
+++ b/helpers/koboldai.py
@@ -0,0 +1,327 @@
+import logging
+from typing import Any, Dict, List, Optional
+
+import requests
+import asyncio
+import aiohttp
+
+import random
+import string
+
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
+from langchain.llms.base import LLM
+
+logger = logging.getLogger(__name__)
+
+
+def clean_url(url: str) -> str:
+    """Remove a trailing slash and /api from the url, if present."""
+    if url.endswith("/api"):
+        return url[:-4]
+    elif url.endswith("/"):
+        return url[:-1]
+    else:
+        return url
+
+
+class KoboldApiLLM(LLM):
+    """Kobold API language model.
+
+    It includes several fields that can be used to control the text generation process.
+
+    To use this class, instantiate it with the required parameters and call it with a
+    prompt to generate text. For example:
+
+        kobold = KoboldApiLLM(endpoint="http://localhost:5000")
+        result = kobold("Write a story about a dragon.")
+
+    This will send a POST request to the Kobold API with the provided prompt and
+    generate text.
+ """ + + endpoint: str + """The API endpoint to use for generating text.""" + + use_story: Optional[bool] = False + """ Whether or not to use the story from the KoboldAI GUI when generating text. """ + + use_authors_note: Optional[bool] = False + """Whether to use the author's note from the KoboldAI GUI when generating text. + + This has no effect unless use_story is also enabled. + """ + + use_world_info: Optional[bool] = False + """Whether to use the world info from the KoboldAI GUI when generating text.""" + + use_memory: Optional[bool] = False + """Whether to use the memory from the KoboldAI GUI when generating text.""" + + max_context_length: Optional[int] = 1600 + """Maximum number of tokens to send to the model. + + minimum: 1 + """ + + max_length: Optional[int] = 80 + """Number of tokens to generate. + + maximum: 512 + minimum: 1 + """ + + rep_pen: Optional[float] = 1.12 + """Base repetition penalty value. + + minimum: 1 + """ + + rep_pen_range: Optional[int] = 1024 + """Repetition penalty range. + + minimum: 0 + """ + + rep_pen_slope: Optional[float] = 0.9 + """Repetition penalty slope. + + minimum: 0 + """ + + temperature: Optional[float] = 0.6 + """Temperature value. + + exclusiveMinimum: 0 + """ + + tfs: Optional[float] = 0.9 + """Tail free sampling value. + + maximum: 1 + minimum: 0 + """ + + top_a: Optional[float] = 0.9 + """Top-a sampling value. + + minimum: 0 + """ + + top_p: Optional[float] = 0.95 + """Top-p sampling value. + + maximum: 1 + minimum: 0 + """ + + top_k: Optional[int] = 0 + """Top-k sampling value. + + minimum: 0 + """ + + typical: Optional[float] = 0.5 + """Typical sampling value. + + maximum: 1 + minimum: 0 + """ + + # To store genkeys for each generation + genkeys = {} + is_koboldcpp = False + + @property + def _llm_type(self) -> str: + return "koboldai" + + # Define a helper method to generate the data dict + def _get_parameters( + self, + prompt: str, + stop: Optional[List[str]] = None) -> Dict[str, Any]: + """Get the parameters to send to the API.""" + data: Dict[str, Any] = { + "prompt": prompt, + "use_story": self.use_story, + "use_authors_note": self.use_authors_note, + "use_world_info": self.use_world_info, + "use_memory": self.use_memory, + "max_context_length": self.max_context_length, + "max_length": self.max_length, + "rep_pen": self.rep_pen, + "rep_pen_range": self.rep_pen_range, + "rep_pen_slope": self.rep_pen_slope, + "temperature": self.temperature, + "tfs": self.tfs, + "top_a": self.top_a, + "top_p": self.top_p, + "top_k": self.top_k, + "typical": self.typical, + } + + if stop: + data["stop_sequence"] = stop + + return data + + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> str: + """Call the API and return the output. + + Args: + prompt: The prompt to use for generation. + stop: A list of strings to stop generation when encountered. + + Returns: + The generated text. + + Example: + .. 
+            .. code-block:: python
+
+                from helpers.koboldai import KoboldApiLLM
+
+                llm = KoboldApiLLM(endpoint="http://localhost:5000")
+                llm("Write a story about dragons.")
+        """
+        data = self._get_parameters(prompt, stop)
+
+        response = requests.post(
+            f"{clean_url(self.endpoint)}/api/v1/generate", json=data
+        )
+
+        response.raise_for_status()
+        json_response = response.json()
+
+        if (
+            "results" in json_response
+            and len(json_response["results"]) > 0
+            and "text" in json_response["results"][0]
+        ):
+            text = json_response["results"][0]["text"].strip()
+
+            if stop is not None:
+                for sequence in stop:
+                    if text.endswith(sequence):
+                        text = text[: -len(sequence)].rstrip()
+
+            return text
+        else:
+            raise ValueError(
+                f"Unexpected response format from Kobold API: {json_response}"
+            )
+
+    # New method to call the KoboldAI API asynchronously
+    async def _acall(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        channel_id: Optional[str] = None,
+        **kwargs: Any,
+    ) -> str:
+        """Call the API asynchronously and return the output.
+
+        Args:
+            prompt: The prompt to use for generation.
+            stop: A list of strings to stop generation when encountered.
+
+        Returns:
+            The generated text.
+
+        Example:
+            .. code-block:: python
+
+                from helpers.koboldai import KoboldApiLLM
+
+                llm = KoboldApiLLM(endpoint="http://localhost:5000")
+                llm("Write a story about dragons.")
+        """
+        if self.is_koboldcpp:
+            # Generate a random 10-character genkey
+            genkey = "".join(random.choices(string.ascii_uppercase + string.digits, k=10))
+            print(f"genkey: {genkey}")
+
+            # Store the genkey in a dict mapped to the channel ID
+            self.genkeys[channel_id] = genkey
+            data = self._get_parameters(prompt, stop)
+            data["genkey"] = genkey
+        else:
+            # Plain KoboldAI; a genkey is not required
+            data = self._get_parameters(prompt, stop)
+
+        # Use aiohttp to call the KoboldAI API asynchronously, to avoid blocking
+        async with aiohttp.ClientSession() as session:
+            async with session.post(f"{clean_url(self.endpoint)}/api/v1/generate", json=data) as response:
+
+                response.raise_for_status()
+                json_response = await response.json()
+
+                if (
+                    "results" in json_response
+                    and len(json_response["results"]) > 0
+                    and "text" in json_response["results"][0]
+                ):
+                    text = json_response["results"][0]["text"].strip()
+
+                    if stop is not None:
+                        for sequence in stop:
+                            if text.endswith(sequence):
+                                text = text[: -len(sequence)].rstrip()
+
+                    return text
+                else:
+                    raise ValueError(
+                        f"Unexpected response format from Kobold API: {json_response}"
+                    )
+
+    def check_version(self) -> float:
+        """Check the API version, to distinguish koboldcpp from KoboldAI."""
+        try:
+            response = requests.get(f"{clean_url(self.endpoint)}/api/extra/version")
+            response.raise_for_status()
+            json_response = response.json()
+            self.is_koboldcpp = True
+            print("The endpoint is running koboldcpp instead of KoboldAI. If you use multiple channel IDs, please pass '--multiuser' to koboldcpp.")
+            return float(json_response["version"])
+        except Exception:
+            # Fall back to fetching the KoboldAI version
+            try:
+                response = requests.get(f"{clean_url(self.endpoint)}/api/v1/version")
+                response.raise_for_status()
+                json_response = response.json()
+                self.is_koboldcpp = False
+                print("The endpoint is running KoboldAI instead of koboldcpp.")
+                return 0.0
+            except Exception:
+                raise ValueError("The endpoint is not running KoboldAI or koboldcpp.")
+
+    async def _stop(self, channel_id):
+        """Send an abort request to stop an ongoing AI generation.
+
+        This only applies to koboldcpp.
+        The official KoboldAI API does not support this.
+        """
+        # Look up the genkey before cancelling
+        if channel_id in self.genkeys:
+            genkey = self.genkeys[channel_id]
+
+            payload = {"genkey": genkey}
+
+            try:
+                response = requests.post(f"{clean_url(self.endpoint)}/api/extra/abort", json=payload)
+                if response.status_code == 200 and response.json()["success"]:
+                    print(f"Successfully aborted AI generation for channel ID {channel_id} with genkey {genkey}")
+                else:
+                    print("Error aborting AI generation.")
+
+            except Exception as e:
+                print(f"Error aborting AI generation: {e}")
\ No newline at end of file
diff --git a/helpers/textgen.py b/helpers/textgen.py
index 1365b49..4212a00 100644
--- a/helpers/textgen.py
+++ b/helpers/textgen.py
@@ -3,8 +3,13 @@
 from typing import Any, Dict, Iterator, List, Optional

 import requests
+import asyncio
+import aiohttp

-from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain.llms.base import LLM
 from langchain.pydantic_v1 import Field
 from langchain.schema.output import GenerationChunk
@@ -217,7 +222,7 @@ def _call(
             print(params["stopping_strings"])  # TODO: Remove this line
             request = params.copy()
             request["prompt"] = prompt
-            print(request)  # TODO: Remove this line
+            #print(request)  # TODO: Remove this line
             response = requests.post(url, json=request)

             if response.status_code == 200:
@@ -229,6 +234,59 @@ def _call(

         return result

+    # Implement the _acall method, based on the async example in the LangChain GitHub repo
+    async def _acall(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        """Call the textgen web API and return the output.
+
+        Args:
+            prompt: The prompt to use for generation.
+            stop: A list of strings to stop generation when encountered.
+
+        Returns:
+            The generated text.
+
+        Example:
+            .. code-block:: python
+
+                from helpers.textgen import TextGen
+
+                llm = TextGen(model_url="http://localhost:5000")
+                llm("Write a story about llamas.")
+        """
+        if self.streaming:
+            combined_text_output = ""
+            async for chunk in self._astream(
+                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+            ):
+                combined_text_output += chunk.text
+            result = combined_text_output
+
+        else:
+            # Use aiohttp to call the textgen API asynchronously, to avoid blocking
+            async with aiohttp.ClientSession() as session:
+                url = f"{self.model_url}/api/v1/generate"
+                params = self._get_parameters(stop)
+                params["stopping_strings"] = params.pop(
+                    "stop"
+                )  # Rename 'stop' to 'stopping_strings'
+                request = params.copy()
+                request["prompt"] = prompt
+                #response = requests.post(url, json=request)
+
+                async with session.post(url, json=request) as response:
+                    if response.status == 200:
+                        result = (await response.json())["results"][0]["text"]
+                    else:
+                        print(f"ERROR: response: {response}")
+                        result = ""
+
+        return result
+
     def _stream(
         self,
         prompt: str,
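
For context, the cancel-and-abort flow this patch introduces can be exercised on its own. The sketch below is not part of the patch: the endpoint URL, prompt, and sleep duration are placeholder assumptions, and it presumes a koboldcpp server started with '--multiuser'. It tags a generation with a random genkey, then aborts it via /api/extra/abort and cancels the awaiting task, mirroring what chat_command and KoboldApiLLM._stop do per channel.

import asyncio
import random
import string

import aiohttp

ENDPOINT = "http://localhost:5001"  # assumed koboldcpp address; adjust to your setup


async def generate(session: aiohttp.ClientSession, genkey: str) -> str:
    # Tag the request with a genkey so this specific generation can be aborted later.
    payload = {"prompt": "Write a story about dragons.", "max_length": 80, "genkey": genkey}
    async with session.post(f"{ENDPOINT}/api/v1/generate", json=payload) as resp:
        resp.raise_for_status()
        return (await resp.json())["results"][0]["text"]


async def abort(session: aiohttp.ClientSession, genkey: str) -> bool:
    # /api/extra/abort stops the in-flight generation carrying this genkey.
    async with session.post(f"{ENDPOINT}/api/extra/abort", json={"genkey": genkey}) as resp:
        return resp.status == 200 and (await resp.json()).get("success", False)


async def main() -> None:
    genkey = "".join(random.choices(string.ascii_uppercase + string.digits, k=10))
    async with aiohttp.ClientSession() as session:
        task = asyncio.create_task(generate(session, genkey))
        await asyncio.sleep(1)  # pretend a newer user message arrived in the meantime
        if not task.done():
            await abort(session, genkey)  # stop the generation server-side first
            task.cancel()                 # then cancel the local awaiting task
        try:
            print(await task)  # partial text, if the request finished before the cancel
        except asyncio.CancelledError:
            print("Generation cancelled; chat_command would now start a new task.")


asyncio.run(main())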