Skip to content

A few improvements #65

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 95 additions & 9 deletions services/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class BaseClient(ABC):
"""Base class for all clients"""

api_type: str = None
system_prompt = "You are a zsh shell expert, please help me complete the following command, you should only output the completed command, no need to include any other explanation. Do not put completed command in a code block."
system_prompt = "You are a zsh shell expert, please help me complete the following command, you should only output the completed command, do not include any explanation, comments or any other text that is not part of the command. Do not put completed command in a code block."

@abstractmethod
def get_completion(self, full_command: str) -> str:
Expand Down Expand Up @@ -54,9 +54,14 @@ def get_completion(self, full_command: str) -> str:
model=self.config["model"],
messages=[
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": "# list all files with all attributes in current folder"},
{"role": "assistant", "content": "ls -alhi"},
{"role": "user", "content": "# go one directory up"},
{"role": "assistant", "content": "cd .."},
{"role": "user", "content": full_command},
],
temperature=float(self.config.get("temperature", 1.0)),

temperature=float(self.config.get("temperature", 0.0)),
)
return response.choices[0].message.content

Expand Down Expand Up @@ -87,9 +92,25 @@ def __init__(self, config: dict):
self.model = genai.GenerativeModel(self.config["model"])

def get_completion(self, full_command: str) -> str:
chat = self.model.start_chat(history=[])
prompt = f"{self.system_prompt}\n\n{full_command}"
response = chat.send_message(prompt)
chat = self.model.start_chat(history=[
{
"role": "user",
"parts": [f"{self.system_prompt}\n\n# list all files with all attributes in current folder"]
},
{
"role": "model",
"parts": ["ls -alhi"]
},
{
"role": "user",
"parts": ["# go one directory up"]
},
{
"role": "model",
"parts": ["cd .."]
}
])
response = chat.send_message(full_command)
return response.text


Expand Down Expand Up @@ -125,9 +146,13 @@ def get_completion(self, full_command: str) -> str:
model=self.config["model"],
messages=[
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": "# list all files with all attributes in current folder"},
{"role": "assistant", "content": "ls -alhi"},
{"role": "user", "content": "# go one directory up"},
{"role": "assistant", "content": "cd .."},
{"role": "user", "content": full_command},
],
temperature=float(self.config.get("temperature", 1.0)),
temperature=float(self.config.get("temperature", 0.0)),
)
return response.choices[0].message.content

Expand Down Expand Up @@ -164,9 +189,13 @@ def get_completion(self, full_command: str) -> str:
model=self.config["model"],
messages=[
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": "# list all files with all attributes in current folder"},
{"role": "assistant", "content": "ls -alhi"},
{"role": "user", "content": "# go one directory up"},
{"role": "assistant", "content": "cd .."},
{"role": "user", "content": full_command},
],
temperature=float(self.config.get("temperature", 1.0)),
temperature=float(self.config.get("temperature", 0.0)),
)
return response.choices[0].message.content

Expand Down Expand Up @@ -213,6 +242,10 @@ def get_completion(self, full_command: str) -> str:
import json

messages = [
{"role": "user", "content": "# list all files with all attributes in current folder"},
{"role": "assistant", "content": "ls -alhi"},
{"role": "user", "content": "# go one directory up"},
{"role": "assistant", "content": "cd .."},
{"role": "user", "content": full_command}
]

Expand All @@ -223,7 +256,7 @@ def get_completion(self, full_command: str) -> str:
"max_tokens": 1000,
"system": self.system_prompt,
"messages": messages,
"temperature": float(self.config.get("temperature", 1.0))
"temperature": float(self.config.get("temperature", 0.0))
}
else:
raise ValueError(f"Unsupported model: {self.config['model']}")
Expand All @@ -237,9 +270,60 @@ def get_completion(self, full_command: str) -> str:
return response_body["content"][0]["text"]


class OllamaClient(BaseClient):
    """Client that completes shell commands via a local/remote Ollama server.

    config keys:
    - api_type="ollama"
    - base_url (optional): defaults to "http://localhost:11434"
    - model (optional): defaults to "llama3.2" or environment variable OLLAMA_DEFAULT_MODEL
    - temperature (optional): defaults to 0.0.
    """

    api_type = "ollama"
    # Resolved once at import time; OLLAMA_DEFAULT_MODEL overrides the fallback.
    default_model = os.getenv("OLLAMA_DEFAULT_MODEL", "llama3.2")

    def __init__(self, config: dict):
        # Import lazily so the dependency is only required when this client
        # is actually selected; exit with a hint otherwise.
        try:
            import ollama
        except ImportError:
            print(
                "Ollama library is not installed. Please install it using 'pip install ollama'"
            )
            sys.exit(1)

        self.config = config
        self.config["model"] = self.config.get("model", self.default_model)

        # Create ollama client with custom host if specified
        if "base_url" in self.config:
            self.client = ollama.Client(host=self.config["base_url"])
        else:
            self.client = ollama.Client()

    def get_completion(self, full_command: str) -> str:
        """Return the completed command text for *full_command*.

        Uses the same few-shot example turns as the other clients so output
        style stays consistent across providers.
        """
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": "# list all files with all attributes in current folder"},
            {"role": "assistant", "content": "ls -alhi"},
            {"role": "user", "content": "# go one directory up"},
            {"role": "assistant", "content": "cd .."},
            {"role": "user", "content": full_command}
        ]

        response = self.client.chat(
            model=self.config["model"],
            messages=messages,
            options={
                # Deterministic by default, consistent with the other clients.
                "temperature": float(self.config.get("temperature", 0.0))
            },
            # NOTE(review): think=True raises on models without thinking
            # support (including the default llama3.2) — confirm against the
            # Ollama Python API, or read it from config with a safe default.
            think=True
        )

        return response["message"]["content"]


class ClientFactory:
api_types = [OpenAIClient.api_type, GoogleGenAIClient.api_type, GroqClient.api_type, MistralClient.api_type, AmazonBedrock.api_type]
api_types = [OpenAIClient.api_type, GoogleGenAIClient.api_type, GroqClient.api_type, MistralClient.api_type, AmazonBedrock.api_type, OllamaClient.api_type]

@classmethod
def create(cls):
Expand All @@ -263,6 +347,8 @@ def create(cls):
return MistralClient(config)
case AmazonBedrock.api_type:
return AmazonBedrock(config)
case OllamaClient.api_type:
return OllamaClient(config)
case _:
raise KeyError(
f"Specified API type {api_type} is not one of the supported services {cls.api_types}"
Expand Down