diff --git a/py/ai.py b/py/ai.py index 0a57b57..bfde713 100644 --- a/py/ai.py +++ b/py/ai.py @@ -1,4 +1,5 @@ from mistralai import Mistral +from openai import OpenAI import ollama @@ -12,15 +13,15 @@ class AI: options={"temperature": 0.5}, ) - return_class.ai_response[access_token] = "" + with return_class.ai_response_lock: + return_class.ai_response[access_token] = "" for chunk in stream: - return_class.ai_response[access_token] += chunk['message']['content'] + with return_class.ai_response_lock: + return_class.ai_response[access_token] += chunk['message']['content'] @staticmethod - def process_mistralai(model, messages, return_class, access_token): - with open("api_key.txt", 'r') as f: - api_key = f.read().strip() + def process_mistralai(model, messages, return_class, access_token, api_key): client = Mistral(api_key=api_key) @@ -29,7 +30,26 @@ class AI: messages=messages ) - return_class.ai_response[access_token] = "" + with return_class.ai_response_lock: + return_class.ai_response[access_token] = "" for chunk in stream_response: - return_class.ai_response[access_token] += chunk.data.choices[0].delta.content + with return_class.ai_response_lock: + return_class.ai_response[access_token] += chunk.data.choices[0].delta.content + + @staticmethod + def process_openai(model, messages, return_class, access_token, api_key): + + client = OpenAI(api_key=api_key) + + stream_response = client.chat.completions.create( + model=model, + messages=messages + ) + + with return_class.ai_response_lock: + return_class.ai_response[access_token] = "" + + for chunk in stream_response: + with return_class.ai_response_lock: + return_class.ai_response[access_token] += chunk.choices[0].delta.content \ No newline at end of file diff --git a/py/api.py b/py/api.py index b52870f..386cd4f 100644 --- a/py/api.py +++ b/py/api.py @@ -1,6 +1,7 @@ from flask import Flask, request, jsonify from flask_cors import CORS import secrets +import threading from ai import AI from db import DB from OpenSSL import crypto @@ -8,12 +9,13 @@ from OpenSSL import crypto class API: def __init__(self): - self.crypt_size = 4096 + self.crypt_size = 1 self.app = Flask(__name__) self.ai_response = {} self.ai = AI() self.db = DB() self.db.load_database() + self.ai_response_lock = threading.Lock() CORS(self.app) def run(self): @@ -34,10 +36,18 @@ class API: return jsonify({'status': 401, 'error': 'Invalid access token'}) if model_type == "local": - self.ai.process_local(ai_model, messages, self, access_token) - if model_type == "mistral": - self.ai.process_mistralai(ai_model, messages, self, access_token) - return jsonify({'status': 200}) + thread = threading.Thread(target=self.ai.process_local, args=(ai_model, messages, self, access_token)) + thread.start() + thread.join() + return jsonify({'status': 200}) + elif model_type == "mistral": + api_key = data.get('api_key') + thread = threading.Thread(target=self.ai.process_mistralai, args=(ai_model, messages, self, access_token, api_key)) + thread.start() + thread.join() + return jsonify({'status': 200}) + + return jsonify({'status': 401, 'error': 'Invalid AI model type'}) @self.app.route('/interstellar/api/ai_get', methods=['GET']) def get_ai(): @@ -95,7 +105,7 @@ class API: f.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, k).decode("utf-8")) ssl_context = ("cert.pem", "key.pem") - self.app.run(debug=True, host='0.0.0.0', port=5000, ssl_context=ssl_context) + self.app.run(debug=True, host='0.0.0.0', port=5000) if __name__ == '__main__': diff --git a/py/requirements.txt b/py/requirements.txt index 144c571..3c2be3b 100644 --- a/py/requirements.txt +++ b/py/requirements.txt @@ -2,4 +2,5 @@ flask flask-cors ollama mistralai +openai pyOpenSSL \ No newline at end of file