Multithreadding #8

Merged
sageTheDm merged 2 commits from React-Group/interstellar_ai:main into main 2024-09-23 11:56:31 +02:00
3 changed files with 44 additions and 13 deletions

View file

@ -1,4 +1,5 @@
from mistralai import Mistral from mistralai import Mistral
from openai import OpenAI
import ollama import ollama
@ -12,15 +13,15 @@ class AI:
options={"temperature": 0.5}, options={"temperature": 0.5},
) )
with return_class.ai_response_lock:
return_class.ai_response[access_token] = "" return_class.ai_response[access_token] = ""
for chunk in stream: for chunk in stream:
with return_class.ai_response_lock:
return_class.ai_response[access_token] += chunk['message']['content'] return_class.ai_response[access_token] += chunk['message']['content']
@staticmethod @staticmethod
def process_mistralai(model, messages, return_class, access_token): def process_mistralai(model, messages, return_class, access_token, api_key):
with open("api_key.txt", 'r') as f:
api_key = f.read().strip()
client = Mistral(api_key=api_key) client = Mistral(api_key=api_key)
@ -29,7 +30,26 @@ class AI:
messages=messages messages=messages
) )
with return_class.ai_response_lock:
return_class.ai_response[access_token] = "" return_class.ai_response[access_token] = ""
for chunk in stream_response: for chunk in stream_response:
with return_class.ai_response_lock:
return_class.ai_response[access_token] += chunk.data.choices[0].delta.content return_class.ai_response[access_token] += chunk.data.choices[0].delta.content
@staticmethod
def process_openai(model, messages, return_class, access_token, api_key):
client = OpenAI(api_key=api_key)
stream_response = client.chat.completions.create(
model=model,
messages=messages
)
with return_class.ai_response_lock:
return_class.ai_response[access_token] = ""
for chunk in stream_response:
with return_class.ai_response_lock:
return_class.ai_response[access_token] += chunk.choices[0].delta.content

View file

@ -1,6 +1,7 @@
from flask import Flask, request, jsonify from flask import Flask, request, jsonify
from flask_cors import CORS from flask_cors import CORS
import secrets import secrets
import threading
from ai import AI from ai import AI
from db import DB from db import DB
from OpenSSL import crypto from OpenSSL import crypto
@ -8,12 +9,13 @@ from OpenSSL import crypto
class API: class API:
def __init__(self): def __init__(self):
self.crypt_size = 4096 self.crypt_size = 1
self.app = Flask(__name__) self.app = Flask(__name__)
self.ai_response = {} self.ai_response = {}
self.ai = AI() self.ai = AI()
self.db = DB() self.db = DB()
self.db.load_database() self.db.load_database()
self.ai_response_lock = threading.Lock()
CORS(self.app) CORS(self.app)
def run(self): def run(self):
@ -34,10 +36,18 @@ class API:
return jsonify({'status': 401, 'error': 'Invalid access token'}) return jsonify({'status': 401, 'error': 'Invalid access token'})
if model_type == "local": if model_type == "local":
self.ai.process_local(ai_model, messages, self, access_token) thread = threading.Thread(target=self.ai.process_local, args=(ai_model, messages, self, access_token))
if model_type == "mistral": thread.start()
self.ai.process_mistralai(ai_model, messages, self, access_token) thread.join()
return jsonify({'status': 200}) return jsonify({'status': 200})
elif model_type == "mistral":
api_key = data.get('api_key')
thread = threading.Thread(target=self.ai.process_mistralai, args=(ai_model, messages, self, access_token, api_key))
thread.start()
thread.join()
return jsonify({'status': 200})
return jsonify({'status': 401, 'error': 'Invalid AI model type'})
@self.app.route('/interstellar/api/ai_get', methods=['GET']) @self.app.route('/interstellar/api/ai_get', methods=['GET'])
def get_ai(): def get_ai():
@ -95,7 +105,7 @@ class API:
f.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, k).decode("utf-8")) f.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, k).decode("utf-8"))
ssl_context = ("cert.pem", "key.pem") ssl_context = ("cert.pem", "key.pem")
self.app.run(debug=True, host='0.0.0.0', port=5000, ssl_context=ssl_context) self.app.run(debug=True, host='0.0.0.0', port=5000)
if __name__ == '__main__': if __name__ == '__main__':

View file

@ -2,4 +2,5 @@ flask
flask-cors flask-cors
ollama ollama
mistralai mistralai
openai
pyOpenSSL pyOpenSSL