Add Support for Gemini Models – Leverage Larger Context Windows and Cost Benefits #475

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 1 commit into main
8 changes: 6 additions & 2 deletions camel/agents/chat_agent.py
@@ -19,14 +19,15 @@
from tenacity.wait import wait_exponential

from camel.agents import BaseAgent
from camel.configs import ChatGPTConfig
from camel.configs import ChatGPTConfig, GeminiConfig
from camel.messages import ChatMessage, MessageType, SystemMessage
from camel.model_backend import ModelBackend, ModelFactory
from camel.typing import ModelType, RoleType
from camel.utils import (
get_model_token_limit,
num_tokens_from_messages,
openai_api_key_required,
is_gemini_model,
)
from chatdev.utils import log_visualize
try:
@@ -97,7 +98,10 @@ def __init__(
self.role_name: str = system_message.role_name
self.role_type: RoleType = system_message.role_type
self.model: ModelType = (model if model is not None else ModelType.GPT_3_5_TURBO)
self.model_config: ChatGPTConfig = model_config or ChatGPTConfig()
# Use the resolved self.model here, since the `model` argument may be None.
if is_gemini_model(self.model):
    self.model_config = model_config or GeminiConfig()
else:
    self.model_config: ChatGPTConfig = model_config or ChatGPTConfig()
self.model_token_limit: int = get_model_token_limit(self.model)
self.message_window_size: Optional[int] = message_window_size
self.model_backend: ModelBackend = ModelFactory.create(self.model, self.model_config.__dict__)
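Outside the diff itself, a minimal sketch of the configuration branch this hunk introduces. It assumes the project's dependencies (including google-generativeai) are installed and that OPENAI_API_KEY is exported, because camel.utils configures the Gemini client at import time; pick_config is only an illustrative helper, not part of the PR.

from camel.configs import ChatGPTConfig, GeminiConfig
from camel.typing import ModelType
from camel.utils import is_gemini_model

def pick_config(model: ModelType, model_config=None):
    # Gemini models fall back to GeminiConfig defaults; everything else keeps ChatGPTConfig.
    if is_gemini_model(model):
        return model_config or GeminiConfig()
    return model_config or ChatGPTConfig()

print(type(pick_config(ModelType.GEMINI_1_5_FLASH)).__name__)  # GeminiConfig
print(type(pick_config(ModelType.GPT_4O)).__name__)            # ChatGPTConfig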
51 changes: 51 additions & 0 deletions camel/configs.py
@@ -74,3 +74,54 @@ class ChatGPTConfig:
frequency_penalty: float = 0.0
logit_bias: Dict = field(default_factory=dict)
user: str = ""


@dataclass  # dataclass so the instance __dict__ carries these fields into ModelFactory.create
class GeminiConfig:
r"""Defines the parameters for generating chat completions
for Gemini models.

Args:
temperature (float, optional): Sampling temperature to use, between
:obj:`0` and :obj:`2`. Higher values make the output more random,
while lower values make it more focused and deterministic.
(default: :obj:`0.2`)
top_p (float, optional): An alternative to sampling with temperature,
called nucleus sampling, where the model considers the results of
the tokens with top_p probability mass. So :obj:`0.1` means only
the tokens comprising the top 10% probability mass are considered.
(default: :obj:`1.0`)
n (int, optional): How many chat completion choices to generate for
each input message. (default: :obj:`1`)
stream (bool, optional): If True, partial message deltas will be sent
as data-only server-sent events as they become available.
(default: :obj:`False`)
stop (str or list, optional): Up to :obj:`4` sequences where the API
will stop generating further tokens. (default: :obj:`None`)
max_tokens (int, optional): The maximum number of tokens to generate
in the chat completion. The total length of input tokens and
generated tokens is limited by the model's context length.
(default: :obj:`None`)
presence_penalty (float, optional): Number between :obj:`-2.0` and
:obj:`2.0`. Positive values penalize new tokens based on whether
they appear in the text so far, increasing the model's likelihood
to talk about new topics. See more information about frequency and
presence penalties. (default: :obj:`0.0`)
frequency_penalty (float, optional): Number between :obj:`-2.0` and
:obj:`2.0`. Positive values penalize new tokens based on their
existing frequency in the text so far, decreasing the model's
likelihood to repeat the same line verbatim. See more information
about frequency and presence penalties. (default: :obj:`0.0`)
user (str, optional): A unique identifier representing your end-user,
which can help to monitor and detect abuse.
(default: :obj:`""`)
"""
temperature: float = 0.2 # openai default: 1.0
top_p: float = 1.0
n: int = 1
stream: bool = False
stop: Optional[Union[str, Sequence[str]]] = None
max_tokens: Optional[int] = None
presence_penalty: float = 0.0
frequency_penalty: float = 0.0
# logit_bias: Dict = field(default_factory=dict)
user: str = ""
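A small sketch, not part of the diff, of how this config is expected to reach the backend: with the @dataclass decorator the instance's __dict__ holds the fields that ChatAgent passes to ModelFactory.create, and model_backend.py (next file) drops the None entries. The standalone class below just mirrors the fields above for illustration.

from dataclasses import dataclass
from typing import Optional, Sequence, Union

@dataclass
class GeminiConfig:
    # Same defaults as the class added in camel/configs.py.
    temperature: float = 0.2
    top_p: float = 1.0
    n: int = 1
    stream: bool = False
    stop: Optional[Union[str, Sequence[str]]] = None
    max_tokens: Optional[int] = None
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    user: str = ""

config_dict = GeminiConfig().__dict__                                  # what ChatAgent hands to ModelFactory.create
config_dict = {k: v for k, v in config_dict.items() if v is not None}  # the None filtering done in model_backend.py
print(config_dict)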
50 changes: 47 additions & 3 deletions camel/model_backend.py
@@ -21,6 +21,9 @@
from chatdev.statistics import prompt_cost
from chatdev.utils import log_visualize

import google.generativeai as genai
from camel import utils

try:
from openai.types.chat import ChatCompletion

@@ -36,6 +39,11 @@
else:
BASE_URL = None

# Configure google.generativeai with the API key. The Gemini API key must be
# exported as OPENAI_API_KEY in the environment, so it is fine to reference it
# through OPENAI_API_KEY here.
genai.configure(api_key=OPENAI_API_KEY)



class ModelBackend(ABC):
r"""Base class for different model backends.
@@ -65,8 +73,21 @@ def __init__(self, model_type: ModelType, model_config_dict: Dict) -> None:

def run(self, *args, **kwargs):
string = "\n".join([message["content"] for message in kwargs["messages"]])
encoding = tiktoken.encoding_for_model(self.model_type.value)
num_prompt_tokens = len(encoding.encode(string))

# If the model is a Gemini model, use google.generativeai.GenerativeModel.count_tokens
# to get the number of tokens required for the message, and drop any None values
# from self.model_config_dict before the request.

if utils.is_gemini_model(self.model_type):
gemini_model = genai.GenerativeModel(self.model_type.value)
num_prompt_tokens = gemini_model.count_tokens(string).total_tokens

# Filter out None values from the config dictionary
self.model_config_dict = {k: v for k, v in self.model_config_dict.items() if v is not None}
else:
encoding = tiktoken.encoding_for_model(self.model_type.value)
num_prompt_tokens = len(encoding.encode(string))

gap_between_send_receive = 15 * len(kwargs["messages"])
num_prompt_tokens += gap_between_send_receive

@@ -83,6 +104,7 @@ def run(self, *args, **kwargs):
)

num_max_token_map = {
# OpenAI models
"gpt-3.5-turbo": 4096,
"gpt-3.5-turbo-16k": 16384,
"gpt-3.5-turbo-0613": 4096,
@@ -92,7 +114,15 @@
"gpt-4-32k": 32768,
"gpt-4-turbo": 100000,
"gpt-4o": 4096, #100000
"gpt-4o-mini": 16384, #100000
"gpt-4o-mini": 16384, #100000,

# Gemini models
"gemini-2.0-flash-001": 1056768,
"gemini-2.0-flash-lite-preview-02-05": 1056768,
"gemini-1.5-flash": 1056768,
"gemini-1.5-flash-8b": 1056768,
"gemini-1.5-pro": 2105344,

}
num_max_token = num_max_token_map[self.model_type.value]
num_max_completion_tokens = num_max_token - num_prompt_tokens
@@ -126,6 +156,14 @@ def run(self, *args, **kwargs):
"gpt-4-turbo": 100000,
"gpt-4o": 4096, #100000
"gpt-4o-mini": 16384, #100000

# Gemini models
"gemini-2.0-flash-001": 1056768,
"gemini-2.0-flash-lite-preview-02-05": 1056768,
"gemini-1.5-flash": 1056768,
"gemini-1.5-flash-8b": 1056768,
"gemini-1.5-pro": 2105344,

}
num_max_token = num_max_token_map[self.model_type.value]
num_max_completion_tokens = num_max_token - num_prompt_tokens
@@ -188,6 +226,12 @@ def create(model_type: ModelType, model_config_dict: Dict) -> ModelBackend:
ModelType.GPT_4_TURBO_V,
ModelType.GPT_4O,
ModelType.GPT_4O_MINI,

ModelType.GEMINI_2_0_FLASH_001,
ModelType.GEMINI_2_0_FLASH_LITE_PREVIEW_02_05,
ModelType.GEMINI_1_5_FLASH,
ModelType.GEMINI_1_5_FLASH_8B,
ModelType.GEMINI_1_5_PRO,
None
}:
model_class = OpenAIModel
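For illustration, a hedged sketch of the two token-counting paths this file now switches between. count_prompt_tokens is a hypothetical helper (the PR does this inline in OpenAIModel.run), and the startswith check stands in for utils.is_gemini_model; it assumes google-generativeai and tiktoken are installed and that genai has already been configured with a valid key, since count_tokens calls the Gemini API.

import google.generativeai as genai
import tiktoken

def count_prompt_tokens(model_name: str, text: str) -> int:
    if model_name.startswith("gemini"):
        # Gemini path: the SDK asks the API for an exact count (network call).
        gemini_model = genai.GenerativeModel(model_name)
        return gemini_model.count_tokens(text).total_tokens
    # OpenAI path: count locally with tiktoken.
    encoding = tiktoken.encoding_for_model(model_name)
    return len(encoding.encode(text))

# Example usage (the Gemini branch requires a configured API key):
# count_prompt_tokens("gemini-1.5-flash", "Hello, world")
# count_prompt_tokens("gpt-4o-mini", "Hello, world")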
8 changes: 8 additions & 0 deletions camel/typing.py
@@ -44,6 +44,7 @@ class RoleType(Enum):


class ModelType(Enum):
# GPT models
GPT_3_5_TURBO = "gpt-3.5-turbo-16k-0613"
GPT_3_5_TURBO_NEW = "gpt-3.5-turbo-16k"
GPT_4 = "gpt-4"
@@ -55,6 +56,13 @@

STUB = "stub"

# Gemini Models
GEMINI_2_0_FLASH_001 = "gemini-2.0-flash-001"
GEMINI_2_0_FLASH_LITE_PREVIEW_02_05 = "gemini-2.0-flash-lite-preview-02-05"
GEMINI_1_5_FLASH = "gemini-1.5-flash"
GEMINI_1_5_FLASH_8B = "gemini-1.5-flash-8b"
GEMINI_1_5_PRO = "gemini-1.5-pro"

@property
def value_for_tiktoken(self):
return self.value if self.name != "STUB" else "gpt-3.5-turbo-16k-0613"
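For reference, a tiny sketch (not in the diff) of how the new enum members resolve, given the values added above:

from camel.typing import ModelType

model = ModelType.GEMINI_1_5_PRO
print(model.value)                        # "gemini-1.5-pro", the string keyed on by the token maps and the Gemini SDK
print(ModelType.STUB.value_for_tiktoken)  # "gpt-3.5-turbo-16k-0613", unchanged by this PR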
69 changes: 68 additions & 1 deletion camel/utils.py
@@ -27,6 +27,12 @@

import time

import google.generativeai as genai
# google.generativeai must be configured with a Gemini API key. Because this project
# reads its key from OPENAI_API_KEY, the Gemini key has to be exported under that
# variable, so it is fine to reference it through OPENAI_API_KEY here.
genai.configure(api_key=os.environ['OPENAI_API_KEY'])


def count_tokens_openai_chat_models(
messages: List[OpenAIMessage],
@@ -53,6 +59,31 @@ def count_tokens_openai_chat_models(
num_tokens += 2 # every reply is primed with <im_start>assistant
return num_tokens

def count_tokens_genai_chat_models(
messages: List[OpenAIMessage],
genai_model,
) -> int:
r"""Counts the number of tokens required to generate a google.GenerativeModel chat based
on a given list of messages.

Args:
messages (List[OpenAIMessage]): The list of messages.
genai_model (google.generativeai.GenerativeModel): The model instance.

Returns:
int: The number of tokens required.
"""
num_tokens = 0
for message in messages:
# per-message overhead mirrors the OpenAI heuristic: <im_start>{role/name}\n{content}<im_end>\n
num_tokens += 4
for key, value in message.items():
num_tokens += genai_model.count_tokens(value).total_tokens
if key == "name": # if there's a name, the role is omitted
num_tokens += -1 # role is always 1 token
num_tokens += 2 # every reply is primed with <im_start>assistant
return num_tokens


def num_tokens_from_messages(
messages: List[OpenAIMessage],
@@ -91,9 +122,19 @@ def num_tokens_from_messages(
ModelType.GPT_4_TURBO_V,
ModelType.GPT_4O,
ModelType.GPT_4O_MINI,
ModelType.STUB
ModelType.STUB,

}:
return count_tokens_openai_chat_models(messages, encoding)
elif model in {
ModelType.GEMINI_2_0_FLASH_001,
ModelType.GEMINI_2_0_FLASH_LITE_PREVIEW_02_05,
ModelType.GEMINI_1_5_FLASH,
ModelType.GEMINI_1_5_FLASH_8B,
ModelType.GEMINI_1_5_PRO,
}:
gemini_model = genai.GenerativeModel(model.value)
return count_tokens_genai_chat_models(messages, gemini_model)
else:
raise NotImplementedError(
f"`num_tokens_from_messages`` is not presently implemented "
@@ -114,6 +155,7 @@ def get_model_token_limit(model: ModelType) -> int:
Returns:
int: The maximum token limit for the given model.
"""
# OpenAI models
if model == ModelType.GPT_3_5_TURBO:
return 16384
elif model == ModelType.GPT_3_5_TURBO_NEW:
@@ -130,6 +172,18 @@
return 128000
elif model == ModelType.GPT_4O_MINI:
return 128000

# Gemini models
elif model == ModelType.GEMINI_2_0_FLASH_001:
return 1056768
elif model == ModelType.GEMINI_2_0_FLASH_LITE_PREVIEW_02_05:
return 1056768
elif model == ModelType.GEMINI_1_5_FLASH:
return 1056768
elif model == ModelType.GEMINI_1_5_FLASH_8B:
return 1056768
elif model == ModelType.GEMINI_1_5_PRO:
return 2105344
else:
raise ValueError("Unknown model type")

@@ -233,3 +287,16 @@ def download_tasks(task: TaskType, folder_path: str) -> None:

# Delete the zip file
os.remove(zip_file_path)

def is_gemini_model(model_name: ModelType) -> bool:
r"""Checks whether a ModelType refers to a Gemini model."""
return model_name.value in [
"gemini-2.0-flash-001",
"gemini-2.0-flash-lite-preview-02-05",
"gemini-1.5-flash",
"gemini-1.5-flash-8b",
"gemini-1.5-pro",
]
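A short sketch, outside the PR, exercising the helpers added to this module. It assumes the project's dependencies (openai, tiktoken, google-generativeai) are installed and that the Gemini key is exported as OPENAI_API_KEY before import, since this module calls genai.configure(...) at import time.

import os
os.environ.setdefault("OPENAI_API_KEY", "<your-gemini-api-key>")  # placeholder, must be set before importing camel.utils

from camel.typing import ModelType
from camel.utils import get_model_token_limit, is_gemini_model

model = ModelType.GEMINI_1_5_FLASH
print(is_gemini_model(model))        # True
print(get_model_token_limit(model))  # 1056768, per the branch added above

# Counting tokens for a Gemini model goes through genai.GenerativeModel.count_tokens,
# so camel.utils.num_tokens_from_messages issues a request against the Gemini API
# for these models and needs a real key:
# num_tokens_from_messages([{"role": "user", "content": "Hello"}], model)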
8 changes: 7 additions & 1 deletion run.py
@@ -79,7 +79,7 @@ def get_config(company):
parser.add_argument('--name', type=str, default="Gomoku",
help="Name of software, your software will be generated in WareHouse/name_org_timestamp")
parser.add_argument('--model', type=str, default="GPT_3_5_TURBO",
help="GPT Model, choose from {'GPT_3_5_TURBO', 'GPT_4', 'GPT_4_TURBO', 'GPT_4O', 'GPT_4O_MINI'}")
help="GPT Model: choose from {'GPT_3_5_TURBO', 'GPT_4', 'GPT_4_TURBO', 'GPT_4O', 'GPT_4O_MINI'},\nGemini Models: choose from {'GEMINI_1_5_PRO','GEMINI_2_0_FLASH_001, 'GEMINI_2_0_FLASH_LITE_PREVIEW_02_05', 'GEMINI_1_5_FLASH', 'GEMINI_1_5_FLASH_8B'}")
parser.add_argument('--path', type=str, default="",
help="Your file directory, ChatDev will build upon your software in the Incremental mode")
args = parser.parse_args()
@@ -97,6 +97,12 @@ def get_config(company):
# 'GPT_4_TURBO_V': ModelType.GPT_4_TURBO_V
'GPT_4O': ModelType.GPT_4O,
'GPT_4O_MINI': ModelType.GPT_4O_MINI,

'GEMINI_1_5_PRO': ModelType.GEMINI_1_5_PRO,
'GEMINI_2_0_FLASH_001': ModelType.GEMINI_2_0_FLASH_001,
'GEMINI_2_0_FLASH_LITE_PREVIEW_02_05': ModelType.GEMINI_2_0_FLASH_LITE_PREVIEW_02_05,
'GEMINI_1_5_FLASH': ModelType.GEMINI_1_5_FLASH,
'GEMINI_1_5_FLASH_8B': ModelType.GEMINI_1_5_FLASH_8B,
}
if openai_new_api:
args2type['GPT_3_5_TURBO'] = ModelType.GPT_3_5_TURBO_NEW
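Finally, a minimal sketch, not part of the change, of how a new --model string flows through run.py's args2type mapping. The mapping below is a subset of the one above, and the same assumption applies when the pipeline actually runs: the Gemini key must be exported as OPENAI_API_KEY.

import argparse

from camel.typing import ModelType

# Subset of run.py's args2type mapping, including two of the new Gemini entries.
args2type = {
    'GPT_3_5_TURBO': ModelType.GPT_3_5_TURBO,
    'GEMINI_2_0_FLASH_001': ModelType.GEMINI_2_0_FLASH_001,
    'GEMINI_1_5_PRO': ModelType.GEMINI_1_5_PRO,
}

parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default="GPT_3_5_TURBO")
args = parser.parse_args(["--model", "GEMINI_2_0_FLASH_001"])

model_type = args2type[args.model]
print(model_type)        # ModelType.GEMINI_2_0_FLASH_001
print(model_type.value)  # "gemini-2.0-flash-001"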