From 83e0614bdb9c7d62e12f0d134121548927b0fa35 Mon Sep 17 00:00:00 2001 From: Patrick_Pluto Date: Mon, 23 Sep 2024 11:57:16 +0200 Subject: [PATCH] More models --- py/ai.py | 23 +++++++++++++++++++++-- py/api.py | 12 ++++++++++++ py/requirements.txt | 1 + 3 files changed, 34 insertions(+), 2 deletions(-) diff --git a/py/ai.py b/py/ai.py index bfde713..a8c5537 100644 --- a/py/ai.py +++ b/py/ai.py @@ -1,5 +1,6 @@ from mistralai import Mistral from openai import OpenAI +import anthropic import ollama @@ -44,7 +45,8 @@ class AI: stream_response = client.chat.completions.create( model=model, - messages=messages + messages=messages, + stream=True ) with return_class.ai_response_lock: @@ -52,4 +54,21 @@ class AI: for chunk in stream_response: with return_class.ai_response_lock: - return_class.ai_response[access_token] += chunk.choices[0].delta.content \ No newline at end of file + return_class.ai_response[access_token] += chunk.choices[0].delta.content + + @staticmethod + def process_anthropic(model, messages, return_class, access_token, api_key): + + client = anthropic.Anthropic(api_key=api_key) + + with return_class.ai_response_lock: + return_class.ai_response[access_token] = "" + + with client.messages.stream( + max_tokens=1024, + model=model, + messages=messages, + ) as stream: + for text in stream.text_stream: + with return_class.ai_response_lock: + return_class.ai_response[access_token] += text diff --git a/py/api.py b/py/api.py index 386cd4f..55d0483 100644 --- a/py/api.py +++ b/py/api.py @@ -46,6 +46,18 @@ class API: thread.start() thread.join() return jsonify({'status': 200}) + elif model_type == "openai": + api_key = data.get('api_key') + thread = threading.Thread(target=self.ai.process_openai, args=(ai_model, messages, self, access_token, api_key)) + thread.start() + thread.join() + return jsonify({'status': 200}) + elif model_type == "anthropic": + api_key = data.get('api_key') + thread = threading.Thread(target=self.ai.process_anthropic, args=(ai_model, messages, self, access_token, api_key)) + thread.start() + thread.join() + return jsonify({'status': 200}) return jsonify({'status': 401, 'error': 'Invalid AI model type'}) diff --git a/py/requirements.txt b/py/requirements.txt index 3c2be3b..bd93ad7 100644 --- a/py/requirements.txt +++ b/py/requirements.txt @@ -3,4 +3,5 @@ flask-cors ollama mistralai openai +anthropic pyOpenSSL \ No newline at end of file -- 2.39.5