Add Mistral notebook with OCR + Small 3.1 #3571

229 changes: 229 additions & 0 deletions sdk/python/foundation-models/mistral/mistralai-ocr.ipynb
@@ -0,0 +1,229 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install python-dotenv azure.identity azure-ai-inference"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"from typing import Any\n",
"import json\n",
"import base64\n",
"from azure.ai.inference import ChatCompletionsClient\n",
"from azure.ai.inference.models import (\n",
" SystemMessage,\n",
" UserMessage,\n",
" TextContentItem,\n",
" ImageContentItem,\n",
" ImageUrl,\n",
" ImageDetailLevel,\n",
")\n",
"from azure.core.credentials import AzureKeyCredential\n",
"import urllib"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def encode_file(image_path):\n",
" \"\"\"Encode the image to base64.\"\"\"\n",
" try:\n",
" with open(image_path, \"rb\") as image_file:\n",
" return base64.b64encode(image_file.read()).decode(\"utf-8\")\n",
" except FileNotFoundError:\n",
" print(f\"Error: The file {image_path} was not found.\")\n",
" return None\n",
" except Exception as e: # Added general exception handling\n",
" print(f\"Error: {e}\")\n",
" return None\n",
"\n",
"\n",
"def ocr_analysis(local_path: str, url: str, ocr_api_key: str) -> Any:\n",
" base64_string = encode_file(local_path)\n",
"\n",
" if local_path.split(\".\")[-1] == \"pdf\":\n",
" data = {\n",
" \"model\": \"mistral-ocr-latest\", # Replace with the appropriate model name\n",
" \"include_image_base64\": \"true\",\n",
" \"document\": {\n",
" \"type\": \"document_url\",\n",
" \"document_url\": f\"data:image/jpeg;base64,{base64_string}\",\n",
" },\n",
" }\n",
" else:\n",
" data = {\n",
" \"model\": \"mistral-ocr-latest\", # Replace with the appropriate model name,\n",
" \"include_image_base64\": \"true\",\n",
" \"document\": {\n",
" \"type\": \"image_url\",\n",
" \"image_url\": f\"data:image/jpeg;base64,{base64_string}\",\n",
" },\n",
" }\n",
" headers = {\n",
" \"Content-Type\": \"application/json\",\n",
" \"Accept\": \"application/json\",\n",
" \"Authorization\": (\"Bearer \" + ocr_api_key),\n",
" }\n",
" body = str.encode(json.dumps(data))\n",
" req = urllib.request.Request(url, body, headers)\n",
"\n",
" try:\n",
" response = urllib.request.urlopen(req)\n",
" result = response.read().decode(\"utf-8\") # Decode the response to a string\n",
" json_result = json.loads(result) # Parse the string into a JSON object\n",
" return json_result\n",
" # print(json.dumps(json_result, indent=4)) # Pretty-print the JSON object\n",
" except urllib.error.HTTPError as error:\n",
" print(\"The request failed with status code: \" + str(error.code))\n",
" # Print the headers - they include the request ID and the timestamp, which are useful for debugging the failure\n",
" print(error.info())\n",
" print(error.read().decode(\"utf8\", \"ignore\"))\n",
" except json.JSONDecodeError:\n",
" print(\"Failed to decode JSON from the response.\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"OCR output"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"url = \"https://you-ocr-deployment-name.region.models.ai.azure.com/v1/ocr\" # copy and paste this from your deployment\n",
"# Path to your image\n",
"image_path = \"/path-to-your-image/image.png\"\n",
"# Replace this with the primary/secondary key, AMLToken, or Microsoft Entra ID token for the endpoint\n",
"pdf_path = \"/path-too-your-pdf/sample.pdf\"\n",
"ocr_api_key = \"api-key-for-ocr-deployment\"\n",
"ocr_output = ocr_analysis(image_path, url, ocr_api_key)\n",
"ocr_output"
]
},
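{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally, instead of hardcoding the endpoint URL and key above, you can keep them in a `.env` file and load them with `python-dotenv` (installed in the first cell). This is a minimal sketch; the variable names `OCR_ENDPOINT` and `OCR_API_KEY` are illustrative, not required by the SDK."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"from dotenv import load_dotenv\n",
"\n",
"# Read a .env file from the working directory; the variable names below are illustrative.\n",
"load_dotenv()\n",
"url = os.getenv(\"OCR_ENDPOINT\", url)\n",
"ocr_api_key = os.getenv(\"OCR_API_KEY\", ocr_api_key)"
]
},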
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ocr_markdown = ocr_output[\"pages\"][0][\"markdown\"]\n",
"ocr_markdown"
]
},
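{
"cell_type": "markdown",
"metadata": {},
"source": [
"For a multi-page PDF, the OCR response contains one entry per page. The sketch below simply joins the markdown of every page, assuming the same `pages`/`markdown` structure used above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Join the markdown of every page in the OCR response (useful for multi-page PDFs).\n",
"full_markdown = \"\\n\\n\".join(page[\"markdown\"] for page in ocr_output[\"pages\"])\n",
"print(full_markdown[:500])  # preview the first 500 characters"
]
},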
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Pairing this with a vision language model\n",
"\n",
"you can extract markdown from any page/ image you've done OCR on and pass this markdown along with the image itself for further analysis. You can get this output as structured output. For example, we pass the first page from the above document to a general purpose vision language model as follows: "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"vlm_endpoint = \"AZURE_INFERENCE_SDK_ENDPOINT\" # something like 'https://xxxxxxxx.services.ai.azure.com/models'\n",
"vlm_api_key = \"API key for Vision Language Model deployment\"\n",
"local_path = \"/path-to-your-image/image.png\"\n",
"\n",
"\n",
"def vlm_augmentation(\n",
" local_path: str, vlm_endpoint: str, vlm_api_key: str, ocr_markdown: str\n",
") -> str:\n",
"\n",
" model_deployment = \"mistral-small-2503\" # or other specific mistral VLM you prefer\n",
"\n",
" client = ChatCompletionsClient(\n",
" endpoint=vlm_endpoint,\n",
" credential=AzureKeyCredential(vlm_api_key),\n",
" headers={\"azureml-model-deployment\": model_deployment},\n",
" )\n",
"\n",
" response = client.complete(\n",
" messages=[\n",
" SystemMessage(\n",
" \"You are an AI assistant that describes images in details. You will be given the markdown representation of the text in the image to provide context for your analysis\"\n",
" ),\n",
" UserMessage(\n",
" [\n",
" TextContentItem(\n",
" text=f\"This is image with OCR output markdown:\\n\\n{ocr_markdown}\\n.\\nAnalyze this image and provide the following in JSON format. For each image: 1) A concise description 2) Main objects/elements present 3) Any text visible in the image 4) Estimated image type (photo, diagram, chart, etc.) 5) If it's a plot, infographic or chart provide the data points and the type of plot, infographic or chart. The output should be strictly be json with no extra commentary\"\n",
" ),\n",
" ImageContentItem(\n",
" image_url=ImageUrl.load(\n",
" image_file=local_path,\n",
" image_format=\"png\",\n",
" detail=ImageDetailLevel.HIGH,\n",
" ),\n",
" ),\n",
" ],\n",
" ),\n",
" ],\n",
" model=model_deployment,\n",
" )\n",
" json_string = response.choices[0].message.content\n",
" # Remove the markdown code block markers if they exist\n",
" if json_string.startswith(\"```json\"):\n",
" json_string = json_string[7:] # Remove ```json\n",
" if json_string.endswith(\"```\"):\n",
" json_string = json_string[:-3] # Remove ```\n",
"\n",
" # Parse the JSON string into a Python object\n",
" json_object = json.loads(json_string)\n",
" return json_object"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"output = vlm_augmentation(local_path, vlm_endpoint, vlm_api_key, ocr_markdown)\n",
"print(json.dumps(output, indent=4))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
}
},
"nbformat": 4,
"nbformat_minor": 4
}