diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..955c04d --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +**/.DS_Store +/.venv +**/__pycache__ +**.pyc + diff --git a/global_methods.py b/global_methods.py index 7ada78d..2b008dd 100644 --- a/global_methods.py +++ b/global_methods.py @@ -5,7 +5,7 @@ import sys import os -import google.generativeai as genai +from google import genai from anthropic import Anthropic @@ -17,9 +17,18 @@ def set_anthropic_key(): pass def set_gemini_key(): + # This is no longer needed with the new SDK + # The client will automatically use the GEMINI_API_KEY or GOOGLE_API_KEY environment variable + pass - # Or use `os.getenv('GOOGLE_API_KEY')` to fetch an environment variable. - genai.configure(api_key=os.environ['GOOGLE_API_KEY']) +def get_gemini_client(): + # Get API key from environment variables + api_key = os.environ.get('GEMINI_API_KEY') or os.environ.get('GOOGLE_API_KEY') + if not api_key: + raise ValueError("Please set GEMINI_API_KEY or GOOGLE_API_KEY environment variable") + + # Create and return the client + return genai.Client(api_key=api_key) def set_openai_key(): openai.api_key = os.environ['OPENAI_API_KEY'] @@ -79,10 +88,12 @@ def run_claude(query, max_new_tokens, model_name): return message.content[0].text -def run_gemini(model, content: str, max_tokens: int = 0): - +def run_gemini(client, model_name: str, content: str, max_tokens: int = 0): try: - response = model.generate_content(content) + response = client.models.generate_content( + model=model_name, + contents=content + ) return response.text except Exception as e: print(f'{type(e).__name__}: {e}') diff --git a/requirements.txt b/requirements.txt index dcda28d..37970ec 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,316 +1,224 @@ -# This file may be used to create an environment using: -# $ conda create --name --file -# platform: linux-64 -_libgcc_mutex=0.1=main -_openmp_mutex=5.1=1_gnu -accelerate=0.24.1=pypi_0 -aiofiles=23.2.1=pypi_0 -aiohttp=3.8.4=pypi_0 -aiosignal=1.3.1=pypi_0 -altair=5.1.2=pypi_0 -anthropic=0.32.0=pypi_0 -antlr4-python3-runtime=4.9.3=pypi_0 -anyio=3.7.1=pypi_0 -appdirs=1.4.4=pypi_0 -argon2-cffi=23.1.0=pypi_0 -argon2-cffi-bindings=21.2.0=pypi_0 -arrow=1.3.0=pypi_0 -asttokens=2.4.1=pypi_0 -async-lru=2.0.4=pypi_0 -async-timeout=4.0.2=pypi_0 -attrs=22.2.0=pypi_0 -babel=2.13.1=pypi_0 -beautifulsoup4=4.12.2=pypi_0 -bert-score=0.3.13=pypi_0 -blas=1.0=mkl -bleach=6.1.0=pypi_0 -blis=0.7.11=pypi_0 -braceexpand=0.1.7=pypi_0 -brotli-python=1.0.9=py39h6a678d5_7 -bzip2=1.0.8=h7b6447c_0 -ca-certificates=2023.08.22=h06a4308_0 -cachetools=5.4.0=pypi_0 -catalogue=2.0.10=pypi_0 -cchardet=2.1.7=pypi_0 -certifi=2023.7.22=py39h06a4308_0 -cffi=1.15.1=py39h5eee18b_3 -chardet=5.1.0=pypi_0 -charset-normalizer=2.0.4=pyhd3eb1b0_0 -click=8.1.7=pypi_0 -comm=0.1.4=pypi_0 -confection=0.1.3=pypi_0 -contourpy=1.0.7=pypi_0 -cryptography=41.0.3=py39hdda0065_0 -cuda-cudart=11.7.99=0 -cuda-cupti=11.7.101=0 -cuda-libraries=11.7.1=0 -cuda-nvrtc=11.7.99=0 -cuda-nvtx=11.7.91=0 -cuda-runtime=11.7.1=0 -cycler=0.11.0=pypi_0 -cymem=2.0.8=pypi_0 -debugpy=1.8.0=pypi_0 -decorator=5.1.1=pypi_0 -decord=0.6.0=pypi_0 -defusedxml=0.7.1=pypi_0 -diffusers=0.21.4=pypi_0 -distro=1.8.0=pypi_0 -docker-pycreds=0.4.0=pypi_0 -et-xmlfile=1.1.0=pypi_0 -exceptiongroup=1.1.3=pypi_0 -executing=2.0.1=pypi_0 -fastapi=0.104.1=pypi_0 -fastjsonschema=2.18.1=pypi_0 -ffmpeg=4.3=hf484d3e_0 -ffmpy=0.3.1=pypi_0 -filelock=3.9.0=py39h06a4308_0 -fonttools=4.38.0=pypi_0 -fqdn=1.5.1=pypi_0 -freetype=2.12.1=h4a9f257_0 -frozenlist=1.3.3=pypi_0 -fsspec=2023.10.0=pypi_0 -ftfy=6.1.1=pypi_0 -gdown=4.7.1=pypi_0 -giflib=5.2.1=h5eee18b_3 -gitdb=4.0.11=pypi_0 -gitpython=3.1.40=pypi_0 -gmp=6.2.1=h295c915_3 -gmpy2=2.1.2=py39heeb90bb_0 -gnutls=3.6.15=he1e5248_0 -google-ai-generativelanguage=0.6.6=pypi_0 -google-api-core=2.19.1=pypi_0 -google-api-python-client=2.140.0=pypi_0 -google-auth=2.33.0=pypi_0 -google-auth-httplib2=0.2.0=pypi_0 -google-generativeai=0.7.2=pypi_0 -googleapis-common-protos=1.63.2=pypi_0 -gradio=3.24.1=pypi_0 -gradio-client=0.0.8=pypi_0 -grpcio=1.65.4=pypi_0 -grpcio-status=1.62.3=pypi_0 -h11=0.14.0=pypi_0 -httpcore=0.18.0=pypi_0 -httplib2=0.22.0=pypi_0 -httpx=0.25.0=pypi_0 -huggingface-hub=0.17.3=pypi_0 -idna=3.4=py39h06a4308_0 -importlib-metadata=6.8.0=pypi_0 -importlib-resources=5.12.0=pypi_0 -intel-openmp=2023.1.0=hdb19cb5_46305 -iopath=0.1.10=pypi_0 -ipykernel=6.26.0=pypi_0 -ipython=8.17.2=pypi_0 -isoduration=20.11.0=pypi_0 -jedi=0.19.1=pypi_0 -jinja2=3.1.2=py39h06a4308_0 -jiter=0.5.0=pypi_0 -joblib=1.3.2=pypi_0 -jpeg=9e=h5eee18b_1 -json5=0.9.14=pypi_0 -jsonpointer=2.4=pypi_0 -jsonschema=4.19.2=pypi_0 -jsonschema-specifications=2023.7.1=pypi_0 -jupyter-client=8.5.0=pypi_0 -jupyter-core=5.5.0=pypi_0 -jupyter-events=0.8.0=pypi_0 -jupyter-lsp=2.2.0=pypi_0 -jupyter-server=2.9.1=pypi_0 -jupyter-server-terminals=0.4.4=pypi_0 -jupyterlab=4.0.7=pypi_0 -jupyterlab-pygments=0.2.2=pypi_0 -jupyterlab-server=2.25.0=pypi_0 -kiwisolver=1.4.4=pypi_0 -lame=3.100=h7b6447c_0 -langcodes=3.3.0=pypi_0 -lcms2=2.12=h3be6417_0 -ld_impl_linux-64=2.38=h1181459_1 -lerc=3.0=h295c915_0 -libcublas=11.10.3.66=0 -libcufft=10.7.2.124=h4fbf590_0 -libcufile=1.8.0.34=0 -libcurand=10.3.4.52=0 -libcusolver=11.4.0.1=0 -libcusparse=11.7.4.91=0 -libdeflate=1.17=h5eee18b_1 -libffi=3.4.4=h6a678d5_0 -libgcc-ng=11.2.0=h1234567_1 -libgomp=11.2.0=h1234567_1 -libiconv=1.16=h7f8727e_2 -libidn2=2.3.4=h5eee18b_0 -libnpp=11.7.4.75=0 -libnvjpeg=11.8.0.2=0 -libpng=1.6.39=h5eee18b_0 -libstdcxx-ng=11.2.0=h1234567_1 -libtasn1=4.19.0=h5eee18b_0 -libtiff=4.5.1=h6a678d5_0 -libunistring=0.9.10=h27cfd23_0 -libwebp=1.3.2=h11a3e52_0 -libwebp-base=1.3.2=h5eee18b_0 -lightning=2.1.0=pypi_0 -lightning-utilities=0.9.0=pypi_0 -linkify-it-py=2.0.2=pypi_0 -llvmlite=0.41.1=pypi_0 -lz4-c=1.9.4=h6a678d5_0 -markdown-it-py=2.2.0=pypi_0 -markupsafe=2.1.1=py39h7f8727e_0 -matplotlib=3.7.0=pypi_0 -matplotlib-inline=0.1.6=pypi_0 -mdit-py-plugins=0.3.3=pypi_0 -mdurl=0.1.2=pypi_0 -mistune=3.0.2=pypi_0 -mkl=2023.1.0=h213fc3f_46343 -mkl-service=2.4.0=py39h5eee18b_1 -mkl_fft=1.3.8=py39h5eee18b_0 -mkl_random=1.2.4=py39hdb19cb5_0 -mpc=1.1.0=h10f8cd9_1 -mpfr=4.0.2=hb69a4c5_1 -mpmath=1.3.0=py39h06a4308_0 -multidict=6.0.4=pypi_0 -murmurhash=1.0.10=pypi_0 -nbclient=0.8.0=pypi_0 -nbconvert=7.10.0=pypi_0 -nbformat=5.9.2=pypi_0 -ncurses=6.4=h6a678d5_0 -nest-asyncio=1.5.8=pypi_0 -nettle=3.7.3=hbbd107a_1 -networkx=3.1=py39h06a4308_0 -nltk=3.8.1=pypi_0 -notebook=7.0.6=pypi_0 -notebook-shim=0.2.3=pypi_0 -numba=0.58.1=pypi_0 -numpy=1.26.0=py39h5f9d8c6_0 -numpy-base=1.26.0=py39hb5e798b_0 -omegaconf=2.3.0=pypi_0 -open-clip-torch=2.23.0=pypi_0 -openai=0.28.0=pypi_0 -opencv-python=4.8.1.78=pypi_0 -openh264=2.1.1=h4ff587b_0 -openjpeg=2.4.0=h3ad879b_0 -openpyxl=3.1.2=pypi_0 -openssl=3.0.11=h7f8727e_2 -orjson=3.9.10=pypi_0 -overrides=7.4.0=pypi_0 -packaging=23.0=pypi_0 -pandas=2.1.2=pypi_0 -pandocfilters=1.5.0=pypi_0 -parso=0.8.3=pypi_0 -pathtools=0.1.2=pypi_0 -pathy=0.10.3=pypi_0 -peft=0.5.0=pypi_0 -pexpect=4.8.0=pypi_0 -pillow=10.0.1=py39ha6cbd5a_0 -pip=23.3=py39h06a4308_0 -platformdirs=3.11.0=pypi_0 -portalocker=2.8.2=pypi_0 -preshed=3.0.9=pypi_0 -prometheus-client=0.18.0=pypi_0 -prompt-toolkit=3.0.39=pypi_0 -proto-plus=1.24.0=pypi_0 -protobuf=4.25.0=pypi_0 -psutil=5.9.4=pypi_0 -ptyprocess=0.7.0=pypi_0 -pure-eval=0.2.2=pypi_0 -pyasn1=0.6.0=pypi_0 -pyasn1-modules=0.4.0=pypi_0 -pycocoevalcap=1.2=pypi_0 -pycocotools=2.0.6=pypi_0 -pycparser=2.21=pyhd3eb1b0_0 -pydantic=1.10.13=pypi_0 -pydub=0.25.1=pypi_0 -pygments=2.16.1=pypi_0 -pynndescent=0.5.10=pypi_0 -pyopenssl=23.2.0=py39h06a4308_0 -pyparsing=3.0.9=pypi_0 -pysocks=1.7.1=py39h06a4308_0 -python=3.9.18=h955ad1f_0 -python-dateutil=2.8.2=pypi_0 -python-json-logger=2.0.7=pypi_0 -python-multipart=0.0.6=pypi_0 -pytorch=2.0.1=py3.9_cuda11.7_cudnn8.5.0_0 -pytorch-cuda=11.7=h778d358_5 -pytorch-fid=0.3.0=pypi_0 -pytorch-lightning=2.1.0=pypi_0 -pytorch-mutex=1.0=cuda -pytz=2023.3.post1=pypi_0 -pyyaml=6.0=pypi_0 -pyzmq=25.1.1=pypi_0 -readline=8.2=h5eee18b_0 -referencing=0.30.2=pypi_0 -regex=2022.10.31=pypi_0 -requests=2.31.0=py39h06a4308_0 -rfc3339-validator=0.1.4=pypi_0 -rfc3986-validator=0.1.1=pypi_0 -rouge=1.0.1=pypi_0 -rpds-py=0.10.6=pypi_0 -rsa=4.9=pypi_0 -safetensors=0.4.0=pypi_0 -scikit-learn=1.3.2=pypi_0 -scipy=1.11.3=pypi_0 -seaborn=0.13.0=pypi_0 -semantic-version=2.10.0=pypi_0 -send2trash=1.8.2=pypi_0 -sentence-transformers=2.2.2=pypi_0 -sentencepiece=0.1.99=pypi_0 -sentry-sdk=1.34.0=pypi_0 -setproctitle=1.3.3=pypi_0 -setuptools=68.0.0=py39h06a4308_0 -six=1.16.0=pypi_0 -smart-open=6.4.0=pypi_0 -smmap=5.0.1=pypi_0 -sniffio=1.3.0=pypi_0 -soupsieve=2.5=pypi_0 -spacy=3.5.1=pypi_0 -spacy-legacy=3.0.12=pypi_0 -spacy-loggers=1.0.5=pypi_0 -sqlite=3.41.2=h5eee18b_0 -srsly=2.4.8=pypi_0 -stack-data=0.6.3=pypi_0 -starlette=0.27.0=pypi_0 -sympy=1.11.1=py39h06a4308_0 -tbb=2021.10.0=pypi_0 -tenacity=8.2.2=pypi_0 -terminado=0.17.1=pypi_0 -thinc=8.1.12=pypi_0 -threadpoolctl=3.2.0=pypi_0 -tiktoken=0.5.2=pypi_0 -timm=0.6.13=pypi_0 -tinycss2=1.2.1=pypi_0 -tk=8.6.12=h1ccaba5_0 -tokenizers=0.14.1=pypi_0 -tomli=2.0.1=pypi_0 -toolz=0.12.0=pypi_0 -torch-fidelity=0.3.0=pypi_0 -torchaudio=2.0.2=py39_cu117 -torchmetrics=1.2.0=pypi_0 -torchtriton=2.0.0=py39 -torchvision=0.15.2=py39_cu117 -tornado=6.3.3=pypi_0 -tqdm=4.64.1=pypi_0 -traitlets=5.13.0=pypi_0 -transformers=4.35.0=pypi_0 -typer=0.7.0=pypi_0 -types-python-dateutil=2.8.19.14=pypi_0 -typing-extensions=4.8.0=pypi_0 -tzdata=2023.3=pypi_0 -uc-micro-py=1.0.2=pypi_0 -umap-learn=0.5.4=pypi_0 -uri-template=1.3.0=pypi_0 -uritemplate=4.1.1=pypi_0 -urllib3=1.26.18=py39h06a4308_0 -uvicorn=0.23.2=pypi_0 -wandb=0.15.12=pypi_0 -wasabi=1.1.2=pypi_0 -wcwidth=0.2.9=pypi_0 -webcolors=1.13=pypi_0 -webdataset=0.2.48=pypi_0 -webencodings=0.5.1=pypi_0 -websocket-client=1.6.4=pypi_0 -websockets=12.0=pypi_0 -wheel=0.41.2=py39h06a4308_0 -wordcloud=1.9.3=pypi_0 -xformers=0.0.22.post7=py39_cu11.8.0_pyt2.0.1 -xz=5.4.2=h5eee18b_0 -yarl=1.8.2=pypi_0 -zipp=3.14.0=pypi_0 -zlib=1.2.13=h5eee18b_0 -zstd=1.5.5=hc292b87_0 +accelerate==0.24.1 +aiofiles==23.2.1 +aiohttp==3.8.4 +aiosignal==1.3.1 +altair==5.1.2 +anthropic==0.32.0 +antlr4-python3-runtime==4.9.3 +anyio>=4.8.0,<5.0.0 +appdirs==1.4.4 +argon2-cffi==23.1.0 +argon2-cffi-bindings==21.2.0 +arrow==1.3.0 +asttokens==2.4.1 +async-lru==2.0.4 +async-timeout==4.0.2 +attrs==22.2.0 +babel==2.13.1 +beautifulsoup4==4.12.2 +bert-score==0.3.13 +bleach==6.1.0 +braceexpand==0.1.7 +brotli==1.0.9 +cachetools==5.4.0 +cchardet==2.1.7 +certifi==2023.7.22 +cffi==1.15.1 +chardet==5.1.0 +charset-normalizer==2.0.4 +click==8.1.7 +comm==0.1.4 +contourpy==1.0.7 +cryptography==41.0.3 +cycler==0.11.0 +debugpy==1.8.0 +decorator==5.1.1 +defusedxml==0.7.1 +diffusers==0.21.4 +distro==1.8.0 +docker-pycreds==0.4.0 +et-xmlfile==1.1.0 +exceptiongroup==1.1.3 +executing==2.0.1 +fastapi>=0.110.0 +fastjsonschema==2.18.1 +ffmpy==0.3.1 +filelock==3.9.0 +fonttools==4.38.0 +fqdn==1.5.1 +frozenlist==1.3.3 +fsspec==2023.10.0 +ftfy==6.1.1 +gdown==4.7.1 +gitdb==4.0.11 +gitpython==3.1.40 +gmpy2==2.1.2 +google-ai-generativelanguage==0.6.6 +google-api-core==2.19.1 +google-api-python-client==2.140.0 +google-auth==2.33.0 +google-auth-httplib2==0.2.0 +google-genai>=0.2.0 +googleapis-common-protos==1.63.2 +gradio==3.24.1 +gradio-client==0.0.8 +grpcio==1.65.4 +grpcio-status==1.62.3 +h11==0.14.0 +httpcore>=1.0.0,<2.0.0 +httplib2==0.22.0 +httpx>=0.28.1,<1.0.0 +huggingface-hub==0.17.3 +idna==3.4 +importlib-metadata==6.8.0 +importlib-resources==5.12.0 +iopath==0.1.10 +ipykernel==6.26.0 +ipython==8.17.2 +isoduration==20.11.0 +jedi==0.19.1 +jinja2==3.1.2 +jiter==0.5.0 +joblib==1.3.2 +json5==0.9.14 +jsonpointer==2.4 +jsonschema==4.19.2 +jsonschema-specifications==2023.7.1 +jupyter-client==8.5.0 +jupyter-core==5.5.0 +jupyter-events==0.8.0 +jupyter-lsp==2.2.0 +jupyter-server==2.9.1 +jupyter-server-terminals==0.4.4 +jupyterlab==4.0.7 +jupyterlab-pygments==0.2.2 +jupyterlab-server==2.25.0 +kiwisolver==1.4.4 +lightning==2.1.0 +lightning-utilities==0.9.0 +linkify-it-py==2.0.2 +llvmlite==0.41.1 +markdown-it-py==2.2.0 +markupsafe==2.1.1 +matplotlib==3.7.0 +matplotlib-inline==0.1.6 +mdit-py-plugins==0.3.3 +mdurl==0.1.2 +mistune==3.0.2 +mpfr +mpmath==1.3.0 +multidict==6.0.4 +nbclient==0.8.0 +nbconvert==7.10.0 +nbformat==5.9.2 +nest-asyncio==1.5.8 +networkx==3.1 +nltk==3.8.1 +notebook==7.0.6 +notebook-shim==0.2.3 +numba==0.58.1 +numpy==1.26.0 +omegaconf==2.3.0 +open-clip-torch==2.23.0 +openai==0.28.0 +opencv-python==4.8.1.78 +openpyxl==3.1.2 +orjson==3.9.10 +overrides==7.4.0 +packaging==23.0 +pandas==2.1.2 +pandocfilters==1.5.0 +parso==0.8.3 +pathtools==0.1.2 +peft==0.5.0 +pexpect==4.8.0 +pillow==10.0.1 +platformdirs==3.11.0 +portalocker==2.8.2 +prometheus-client==0.18.0 +prompt-toolkit==3.0.39 +proto-plus==1.24.0 +protobuf==4.25.0 +psutil==5.9.4 +ptyprocess==0.7.0 +pure-eval==0.2.2 +pyasn1==0.6.0 +pyasn1-modules==0.4.0 +pycocoevalcap==1.2 +pycocotools==2.0.6 +pycparser==2.21 +pydantic>=2.0.0,<3.0.0 +pydub==0.25.1 +pygments==2.16.1 +pynndescent==0.5.10 +pyopenssl==23.2.0 +pyparsing==3.0.9 +pysocks==1.7.1 +python-dateutil==2.8.2 +python-json-logger==2.0.7 +python-multipart==0.0.6 +pytz==2023.3.post1 +pyyaml==6.0 +pyzmq==25.1.1 +referencing==0.30.2 +regex==2022.10.31 +requests==2.31.0 +rfc3339-validator==0.1.4 +rfc3986-validator==0.1.1 +rouge==1.0.1 +rpds-py==0.10.6 +rsa==4.9 +safetensors==0.4.0 +scikit-learn==1.3.2 +scipy==1.11.3 +seaborn==0.13.0 +semantic-version==2.10.0 +send2trash==1.8.2 +sentence-transformers==2.2.2 +sentencepiece==0.1.99 +sentry-sdk==1.34.0 +setproctitle==1.3.3 +six==1.16.0 +smart-open==6.4.0 +smmap==5.0.1 +sniffio==1.3.0 +soupsieve==2.5 +stack-data==0.6.3 +starlette>=0.36.0 +sympy==1.11.1 +tenacity>=8.2.3,<9.0.0 +terminado==0.17.1 +threadpoolctl==3.2.0 +tiktoken==0.5.2 +timm==0.6.13 +tinycss2==1.2.1 +tokenizers==0.14.1 +tomli==2.0.1 +toolz==0.12.0 +torch-fidelity==0.3.0 +torchaudio==2.0.2 +torchmetrics==1.2.0 +tornado==6.3.3 +tqdm==4.64.1 +traitlets==5.13.0 +transformers==4.35.0 +typer==0.7.0 +types-python-dateutil==2.8.19.14 +typing-extensions==4.8.0 +tzdata==2023.3 +uc-micro-py==1.0.2 +umap-learn==0.5.4 +uri-template==1.3.0 +uritemplate==4.1.1 +urllib3==1.26.18 +uvicorn==0.23.2 +wandb==0.15.12 +wcwidth==0.2.9 +webcolors==1.13 +webdataset==0.2.48 +webencodings==0.5.1 +websocket-client==1.6.4 +websockets>=13.0.0,<15.1.0 +wordcloud==1.9.3 +yarl==1.8.2 +zipp==3.14.0 \ No newline at end of file diff --git a/scripts/evaluate_gemini.sh b/scripts/evaluate_gemini.sh index dadc255..7e0082b 100644 --- a/scripts/evaluate_gemini.sh +++ b/scripts/evaluate_gemini.sh @@ -4,4 +4,4 @@ source scripts/env.sh # Evaluate Gemini Pro python3 task_eval/evaluate_qa.py \ --data-file $DATA_FILE_PATH --out-file $OUT_DIR/$QA_OUTPUT_FILE \ - --model gemini-pro-1.0 --batch-size 20 + --model gemini-2.5-flash-lite-preview-06-17 --batch-size 20 diff --git a/task_eval/__pycache__/__init__.cpython-310.pyc b/task_eval/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index b2be7c0..0000000 Binary files a/task_eval/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/task_eval/__pycache__/__init__.cpython-39.pyc b/task_eval/__pycache__/__init__.cpython-39.pyc deleted file mode 100644 index 78cc45a..0000000 Binary files a/task_eval/__pycache__/__init__.cpython-39.pyc and /dev/null differ diff --git a/task_eval/__pycache__/claude_utils.cpython-39.pyc b/task_eval/__pycache__/claude_utils.cpython-39.pyc deleted file mode 100644 index e4e9330..0000000 Binary files a/task_eval/__pycache__/claude_utils.cpython-39.pyc and /dev/null differ diff --git a/task_eval/__pycache__/dpr_qa.cpython-39.pyc b/task_eval/__pycache__/dpr_qa.cpython-39.pyc deleted file mode 100644 index c3c6110..0000000 Binary files a/task_eval/__pycache__/dpr_qa.cpython-39.pyc and /dev/null differ diff --git a/task_eval/__pycache__/evaluation.cpython-310.pyc b/task_eval/__pycache__/evaluation.cpython-310.pyc deleted file mode 100644 index 71590f0..0000000 Binary files a/task_eval/__pycache__/evaluation.cpython-310.pyc and /dev/null differ diff --git a/task_eval/__pycache__/evaluation.cpython-39.pyc b/task_eval/__pycache__/evaluation.cpython-39.pyc deleted file mode 100644 index 05d6128..0000000 Binary files a/task_eval/__pycache__/evaluation.cpython-39.pyc and /dev/null differ diff --git a/task_eval/__pycache__/evaluation_stats.cpython-310.pyc b/task_eval/__pycache__/evaluation_stats.cpython-310.pyc deleted file mode 100644 index 28755b3..0000000 Binary files a/task_eval/__pycache__/evaluation_stats.cpython-310.pyc and /dev/null differ diff --git a/task_eval/__pycache__/evaluation_stats.cpython-39.pyc b/task_eval/__pycache__/evaluation_stats.cpython-39.pyc deleted file mode 100644 index 5617497..0000000 Binary files a/task_eval/__pycache__/evaluation_stats.cpython-39.pyc and /dev/null differ diff --git a/task_eval/__pycache__/gemini_utils.cpython-39.pyc b/task_eval/__pycache__/gemini_utils.cpython-39.pyc deleted file mode 100644 index 1eff433..0000000 Binary files a/task_eval/__pycache__/gemini_utils.cpython-39.pyc and /dev/null differ diff --git a/task_eval/__pycache__/gpt_utils.cpython-39.pyc b/task_eval/__pycache__/gpt_utils.cpython-39.pyc deleted file mode 100644 index 0e12129..0000000 Binary files a/task_eval/__pycache__/gpt_utils.cpython-39.pyc and /dev/null differ diff --git a/task_eval/__pycache__/hf_llm_utils.cpython-39.pyc b/task_eval/__pycache__/hf_llm_utils.cpython-39.pyc deleted file mode 100644 index 23a440c..0000000 Binary files a/task_eval/__pycache__/hf_llm_utils.cpython-39.pyc and /dev/null differ diff --git a/task_eval/__pycache__/rag_utils.cpython-39.pyc b/task_eval/__pycache__/rag_utils.cpython-39.pyc deleted file mode 100644 index 6921693..0000000 Binary files a/task_eval/__pycache__/rag_utils.cpython-39.pyc and /dev/null differ diff --git a/task_eval/claude_utils.py b/task_eval/claude_utils.py index 44f964a..f14628d 100644 --- a/task_eval/claude_utils.py +++ b/task_eval/claude_utils.py @@ -151,13 +151,23 @@ def get_claude_answers(in_data, out_data, prediction_key, args): if qa['category'] == 2: questions.append(qa['question'] + ' Use DATE of CONVERSATION to answer with an approximate date.') elif qa['category'] == 5: + # Check for both 'answer' and 'adversarial_answer' keys + answer_text = None + if 'answer' in qa: + answer_text = qa['answer'] + elif 'adversarial_answer' in qa: + answer_text = qa['adversarial_answer'] + else: + print(f"Warning: Missing 'answer' or 'adversarial_answer' key in QA item: {qa}") + continue + question = qa['question'] + " Select the correct answer: (a) {} (b) {}. " if random.random() < 0.5: - question = question.format('Not mentioned in the conversation', qa['answer']) - answer = {'a': 'Not mentioned in the conversation', 'b': qa['answer']} + question = question.format('Not mentioned in the conversation', answer_text) + answer = {'a': 'Not mentioned in the conversation', 'b': answer_text} else: - question = question.format(qa['answer'], 'Not mentioned in the conversation') - answer = {'b': 'Not mentioned in the conversation', 'a': qa['answer']} + question = question.format(answer_text, 'Not mentioned in the conversation') + answer = {'b': 'Not mentioned in the conversation', 'a': answer_text} cat_5_idxs.append(len(questions)) questions.append(question) diff --git a/task_eval/evaluate_qa.py b/task_eval/evaluate_qa.py index c3e888c..e0c34b9 100644 --- a/task_eval/evaluate_qa.py +++ b/task_eval/evaluate_qa.py @@ -5,7 +5,7 @@ import os, json from tqdm import tqdm import argparse -from global_methods import set_openai_key, set_anthropic_key, set_gemini_key +from global_methods import set_openai_key, set_anthropic_key, set_gemini_key, get_gemini_client from task_eval.evaluation import eval_question_answering from task_eval.evaluation_stats import analyze_aggr_acc from task_eval.gpt_utils import get_gpt_answers @@ -14,7 +14,15 @@ from task_eval.hf_llm_utils import init_hf_model, get_hf_answers import numpy as np -import google.generativeai as genai + +# Category mapping for QA evaluation +CATEGORY_MAPPING = { + 1: "Multi-hop", + 2: "Temporal", + 3: "Open-domain", + 4: "Single-hop", + 5: "Adversarial" +} def parse_args(): @@ -50,12 +58,12 @@ def main(): set_anthropic_key() elif 'gemini' in args.model: - # set openai API key - set_gemini_key() + # Get Gemini client + gemini_client = get_gemini_client() + # Map old model names to new ones if needed if args.model == "gemini-pro-1.0": - model_name = "models/gemini-1.0-pro-latest" - - gemini_model = genai.GenerativeModel(model_name) + print("Warning: gemini-pro-1.0 is deprecated. Using gemini-2.5-pro instead.") + args.model = "gemini-2.5-pro" elif any([model_name in args.model for model_name in ['gemma', 'llama', 'mistral']]): hf_pipeline, hf_model_name = init_hf_model(args) @@ -89,7 +97,7 @@ def main(): elif 'claude' in args.model: answers = get_claude_answers(data, out_data, prediction_key, args) elif 'gemini' in args.model: - answers = get_gemini_answers(gemini_model, data, out_data, prediction_key, args) + answers = get_gemini_answers(gemini_client, data, out_data, prediction_key, args) elif any([model_name in args.model for model_name in ['gemma', 'llama', 'mistral']]): answers = get_hf_answers(data, out_data, args, hf_pipeline, hf_model_name) else: @@ -101,6 +109,10 @@ def main(): answers['qa'][i][model_key + '_f1'] = round(exact_matches[i], 3) if args.use_rag and len(recall) > 0: answers['qa'][i][model_key + '_recall'] = round(recall[i], 3) + + # Add category name to output + category_num = answers['qa'][i].get('category', 0) + answers['qa'][i]['category_name'] = CATEGORY_MAPPING.get(category_num, f"Unknown-{category_num}") out_samples[data['sample_id']] = answers diff --git a/task_eval/evaluation.py b/task_eval/evaluation.py index 8f597dd..9c8450a 100644 --- a/task_eval/evaluation.py +++ b/task_eval/evaluation.py @@ -12,6 +12,15 @@ LENGTH_THRESHOLD = 5 +# Category mapping for QA evaluation +CATEGORY_MAPPING = { + 1: "Multi-hop", + 2: "Temporal", + 3: "Open-domain", + 4: "Single-hop", + 5: "Adversarial" +} + class SimpleTokenizer(object): ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+' NON_WS = r'[^\p{Z}\p{C}]' @@ -197,9 +206,23 @@ def eval_question_answering(qas, eval_key='prediction', metric='f1'): for i, line in enumerate(qas): # line = json.loads(line) if type(line[eval_key]) == list: - answer = line['answer'] + # Check for both 'answer' and 'adversarial_answer' keys + if 'answer' in line: + answer = line['answer'] + elif 'adversarial_answer' in line: + answer = line['adversarial_answer'] + else: + print(f"Warning: Missing answer key in evaluation line: {line}") + continue else: - answer = str(line['answer']) + # Check for both 'answer' and 'adversarial_answer' keys + if 'answer' in line: + answer = str(line['answer']) + elif 'adversarial_answer' in line: + answer = str(line['adversarial_answer']) + else: + print(f"Warning: Missing answer key in evaluation line: {line}") + continue if line['category'] == 3: answer = answer.split(';')[0].strip() diff --git a/task_eval/evaluation_stats.py b/task_eval/evaluation_stats.py index 3283edf..c8af2e3 100644 --- a/task_eval/evaluation_stats.py +++ b/task_eval/evaluation_stats.py @@ -3,6 +3,15 @@ from tqdm import tqdm from collections import defaultdict +# Category mapping for QA evaluation +CATEGORY_MAPPING = { + 1: "Multi-hop", + 2: "Temporal", + 3: "Open-domain", + 4: "Single-hop", + 5: "Adversarial" +} + def get_conversation_lengths(data, encoder=None): @@ -92,20 +101,25 @@ def analyze_aggr_acc(ann_file, in_file, out_file, model_name, metric_key, encode print("Total number of questions and corresponding accuracy in each category: ") + print("Category | Name | Count | Correct | Accuracy") + print("-" * 55) total_k = 0 total_v = 0 # for k, v in total_counts.items(): keys = [4, 1, 2, 3, 5] for k in keys: v = total_counts[k] + category_name = CATEGORY_MAPPING.get(k, f"Unknown-{k}") if float(v) == 0.0: - print("No questions found in category %s" % k) + print(f"No questions found in category {k} ({category_name})") else: - print(k, v, acc_counts[k], round(float(acc_counts[k])/v, 3)) + accuracy = round(float(acc_counts[k])/v, 3) + print(f"{k:8} | {category_name:11} | {v:5} | {acc_counts[k]:7.1f} | {accuracy:8.3f}") total_v += acc_counts[k] total_k += v - print("Overall accuracy: ", round(float(total_v)/total_k, 3)) + print("-" * 55) + print(f"Overall accuracy: {round(float(total_v)/total_k, 3):.3f}") # print("Total number of questions and corresponding accuracy by memory") # keys = list(memory_counts_og.keys()) @@ -123,19 +137,40 @@ def analyze_aggr_acc(ann_file, in_file, out_file, model_name, metric_key, encode results_dict[model_name] = {} results_dict[model_name]['category_counts'] = total_counts results_dict[model_name]['cum_accuracy_by_category'] = acc_counts + + # Add category names mapping to output + results_dict[model_name]['category_mapping'] = CATEGORY_MAPPING + + # Add category statistics with names + category_stats = {} + for k in [4, 1, 2, 3, 5]: + if k in total_counts and total_counts[k] > 0: + category_stats[k] = { + 'name': CATEGORY_MAPPING.get(k, f"Unknown-{k}"), + 'count': total_counts[k], + 'correct': acc_counts[k], + 'accuracy': round(float(acc_counts[k])/total_counts[k], 3) + } + results_dict[model_name]['category_statistics'] = category_stats if rag: results_dict[model_name]['recall_by_category'] = {k: v/total_counts[k] for k, v in recall_by_category.items()} print("Category and corresponding recall accuracy in each category: ") + print("Category | Name | Recall") + print("-" * 35) # for k, v in recall_by_category.items(): keys = [4, 1, 2, 3, 5] for k in keys: v = recall_by_category[k] + category_name = CATEGORY_MAPPING.get(k, f"Unknown-{k}") if float(total_counts[k]) == 0.0: - print("No questions found in category %s" % k) + print(f"No questions found in category {k} ({category_name})") else: - print(k, round(float(v)/total_counts[k], 3)) - print("Overall recall accuracy: ", sum(list(recall_by_category.values()))/sum(list(total_counts.values()))) + recall_acc = round(float(v)/total_counts[k], 3) + print(f"{k:8} | {category_name:11} | {recall_acc:6.3f}") + print("-" * 35) + overall_recall = sum(list(recall_by_category.values()))/sum(list(total_counts.values())) + print(f"Overall recall accuracy: {overall_recall:.3f}") else: results_dict[model_name]['category_counts_by_memory'] = memory_counts_og results_dict[model_name]['cum_accuracy_by_category_by_memory'] = memory_counts diff --git a/task_eval/gemini_utils.py b/task_eval/gemini_utils.py index 99204db..d6b1625 100644 --- a/task_eval/gemini_utils.py +++ b/task_eval/gemini_utils.py @@ -6,10 +6,10 @@ import os, json from tqdm import tqdm import time -from global_methods import run_gemini +from global_methods import run_gemini, get_gemini_client -MAX_LENGTH={'gemini-pro-1.0': 1000000} +MAX_LENGTH={'gemini-2.0-flash': 1000000, 'gemini-2.5-flash': 1000000, 'gemini-2.5-pro': 1000000} PER_QA_TOKEN_BUDGET = 50 QA_PROMPT = """ @@ -53,6 +53,10 @@ def process_ouput(text): if v is None: answers[k] = "" continue + # Convert to string if it's not already a string + if not isinstance(v, str): + v = str(v) + answers[k] = v if v.startswith('{') and v.endswith('}'): try: answers[k] = json.loads(v)['answer'] @@ -64,6 +68,11 @@ def process_ouput(text): for k, v in enumerate(answers): if v is None: answers[k] = "" + continue + # Convert to string if it's not already a string + if not isinstance(v, str): + v = str(v) + answers[k] = v if v.startswith('{') and v.endswith('}'): try: answers[k] = json.loads(v)['answer'] @@ -87,7 +96,7 @@ def get_cat_5_answer(model_prediction, answer_key): else: return model_prediction -def get_input_context(data, num_question_tokens, model, args): +def get_input_context(data, num_question_tokens, client, args): query_conv = '' min_session = -1 @@ -103,9 +112,9 @@ def get_input_context(data, num_question_tokens, model, args): turn += ' and shared %s.' % dialog["blip_caption"] turn += '\n' - # num_tokens = model.count_tokens('DATE: ' + data['session_%s_date_time' % i] + '\n' + 'CONVERSATION:\n' + turn).total_tokens + # num_tokens = client.models.count_tokens(model=args.model, contents='DATE: ' + data['session_%s_date_time' % i] + '\n' + 'CONVERSATION:\n' + turn).total_tokens - # if (num_tokens + model.count_tokens(query_conv).total_tokens + num_question_tokens) < (MAX_LENGTH[args.model]-(PER_QA_TOKEN_BUDGET*(args.batch_size))): # 20 tokens assigned for answers + # if (num_tokens + client.models.count_tokens(model=args.model, contents=query_conv).total_tokens + num_question_tokens) < (MAX_LENGTH[args.model]-(PER_QA_TOKEN_BUDGET*(args.batch_size))): # 20 tokens assigned for answers # query_conv = turn + query_conv # else: # min_session = i @@ -119,14 +128,14 @@ def get_input_context(data, num_question_tokens, model, args): break # if min_session == -1: - # print("Saved %s tokens in query conversation from full conversation" % model.count_tokens(query_conv).total_tokens) + # print("Saved %s tokens in query conversation from full conversation" % client.models.count_tokens(model=args.model, contents=query_conv).total_tokens) # else: - # print("Saved %s conv. tokens + %s question tokens in query from %s out of %s sessions" % (model.count_tokens(query_conv).total_tokens, num_question_tokens, max_session-min_session, max_session)) + # print("Saved %s conv. tokens + %s question tokens in query from %s out of %s sessions" % (client.models.count_tokens(model=args.model, contents=query_conv).total_tokens, num_question_tokens, max_session-min_session, max_session)) return query_conv -def get_gemini_answers(model, in_data, out_data, prediction_key, args): +def get_gemini_answers(client, in_data, out_data, prediction_key, args): assert len(in_data['qa']) == len(out_data['qa']), (len(in_data['qa']), len(out_data['qa'])) @@ -134,7 +143,7 @@ def get_gemini_answers(model, in_data, out_data, prediction_key, args): # start instruction prompt speakers_names = list(set([d['speaker'] for d in in_data['conversation']['session_1']])) start_prompt = CONV_START_PROMPT.format(speakers_names[0], speakers_names[1]) - # start_tokens = model.count_tokens(start_prompt).total_tokens + # start_tokens = client.models.count_tokens(model=args.model, contents=start_prompt).total_tokens start_tokens = 100 if args.rag_mode: @@ -164,13 +173,25 @@ def get_gemini_answers(model, in_data, out_data, prediction_key, args): if qa['category'] == 2: questions.append(qa['question'] + ' Use DATE of CONVERSATION to answer with an approximate date.') elif qa['category'] == 5: + # Check for both 'answer' and 'adversarial_answer' keys + answer_text = None + if 'answer' in qa: + answer_text = qa['answer'] + elif 'adversarial_answer' in qa: + answer_text = qa['adversarial_answer'] + else: + print(f"Warning: Missing 'answer' or 'adversarial_answer' key in QA item: {qa}") + print(f"Available keys: {list(qa.keys())}") + # Skip this question if no answer is available + continue + question = qa['question'] + " Select the correct answer: (a) {} (b) {}. " if random.random() < 0.5: - question = question.format('Not mentioned in the conversation', qa['answer']) - answer = {'a': 'Not mentioned in the conversation', 'b': qa['answer']} + question = question.format('Not mentioned in the conversation', answer_text) + answer = {'a': 'Not mentioned in the conversation', 'b': answer_text} else: - question = question.format(qa['answer'], 'Not mentioned in the conversation') - answer = {'b': 'Not mentioned in the conversation', 'a': qa['answer']} + question = question.format(answer_text, 'Not mentioned in the conversation') + answer = {'b': 'Not mentioned in the conversation', 'a': answer_text} cat_5_idxs.append(len(questions)) questions.append(question) @@ -190,22 +211,22 @@ def get_gemini_answers(model, in_data, out_data, prediction_key, args): raise NotImplementedError else: question_prompt = QA_PROMPT_BATCH + "\n".join(["%s: %s" % (k, q) for k, q in enumerate(questions)]) - num_question_tokens = model.count_tokens(question_prompt).total_tokens + num_question_tokens = client.models.count_tokens(model=args.model, contents=question_prompt).total_tokens num_question_tokens = 200 - query_conv = get_input_context(in_data['conversation'], num_question_tokens + start_tokens, model, args) + query_conv = get_input_context(in_data['conversation'], num_question_tokens + start_tokens, client, args) query_conv = start_prompt + query_conv - # print("%s tokens in query" % model.count_tokens(query_conv).total_tokens) + # print("%s tokens in query" % client.models.count_tokens(model=args.model, contents=query_conv).total_tokens) - if 'pro-1.0' in args.model: + if 'pro' in args.model: time.sleep(30) if args.batch_size == 1: query = query_conv + '\n\n' + QA_PROMPT.format(questions[0]) if len(cat_5_idxs) == 0 else query_conv + '\n\n' + QA_PROMPT_CAT_5.format(questions[0]) - answer = run_gemini(model, query) + answer = run_gemini(client, args.model, query) if len(cat_5_idxs) > 0: answer = get_cat_5_answer(answer, cat_5_answers[0]) @@ -225,9 +246,9 @@ def get_gemini_answers(model, in_data, out_data, prediction_key, args): try: trials += 1 # print("Trial %s" % trials) - # print("Sending query of %s tokens" % model.count_tokens(query).total_tokens) + # print("Sending query of %s tokens" % client.models.count_tokens(model=args.model, contents=query).total_tokens) # print("Trying with answer token budget = %s per question" % PER_QA_TOKEN_BUDGET) - answer = run_gemini(model, query) + answer = run_gemini(client, args.model, query) answer = answer.replace('\\"', "'").replace('json','').replace('`','').strip() # try: diff --git a/task_eval/gpt_utils.py b/task_eval/gpt_utils.py index 64d65df..2aadc33 100644 --- a/task_eval/gpt_utils.py +++ b/task_eval/gpt_utils.py @@ -243,13 +243,23 @@ def get_gpt_answers(in_data, out_data, prediction_key, args): if qa['category'] == 2: questions.append(qa['question'] + ' Use DATE of CONVERSATION to answer with an approximate date.') elif qa['category'] == 5: + # Check for both 'answer' and 'adversarial_answer' keys + answer_text = None + if 'answer' in qa: + answer_text = qa['answer'] + elif 'adversarial_answer' in qa: + answer_text = qa['adversarial_answer'] + else: + print(f"Warning: Missing 'answer' or 'adversarial_answer' key in QA item: {qa}") + continue + question = qa['question'] + " Select the correct answer: (a) {} (b) {}. " if random.random() < 0.5: - question = question.format('Not mentioned in the conversation', qa['answer']) - answer = {'a': 'Not mentioned in the conversation', 'b': qa['answer']} + question = question.format('Not mentioned in the conversation', answer_text) + answer = {'a': 'Not mentioned in the conversation', 'b': answer_text} else: - question = question.format(qa['answer'], 'Not mentioned in the conversation') - answer = {'b': 'Not mentioned in the conversation', 'a': qa['answer']} + question = question.format(answer_text, 'Not mentioned in the conversation') + answer = {'b': 'Not mentioned in the conversation', 'a': answer_text} cat_5_idxs.append(len(questions)) questions.append(question) diff --git a/task_eval/hf_llm_utils.py b/task_eval/hf_llm_utils.py index 7d16cf6..3f2afbf 100644 --- a/task_eval/hf_llm_utils.py +++ b/task_eval/hf_llm_utils.py @@ -252,13 +252,23 @@ def get_hf_answers(in_data, out_data, args, pipeline, model_name): if qa['category'] == 2: questions.append(qa['question'] + ' Use DATE of CONVERSATION to answer with an approximate date.') elif qa['category'] == 5: + # Check for both 'answer' and 'adversarial_answer' keys + answer_text = None + if 'answer' in qa: + answer_text = qa['answer'] + elif 'adversarial_answer' in qa: + answer_text = qa['adversarial_answer'] + else: + print(f"Warning: Missing 'answer' or 'adversarial_answer' key in QA item: {qa}") + continue + question = qa['question'] + " (a) {} (b) {}. Select the correct answer by writing (a) or (b)." if random.random() < 0.5: - question = question.format('No information available', qa['answer']) - answer = {'a': 'No information available', 'b': qa['answer']} + question = question.format('No information available', answer_text) + answer = {'a': 'No information available', 'b': answer_text} else: - question = question.format(qa['answer'], 'No information available') - answer = {'b': 'No information available', 'a': qa['answer']} + question = question.format(answer_text, 'No information available') + answer = {'b': 'No information available', 'a': answer_text} cat_5_idxs.append(len(questions)) questions.append(question) cat_5_answers.append(answer)