Merge pull request 'Backend optimisation' (#16) from React-Group/ai-virtual-assistant:main into main
Reviewed-on: https://interstellardevelopment.org/code/code/sageTheDm/ai-virtual-assistant/pulls/16
This commit is contained in:
commit
94660381ad
5 changed files with 104 additions and 31 deletions
3
.gitignore
vendored
3
.gitignore
vendored
|
@ -1,4 +1,5 @@
|
||||||
venv/
|
venv/
|
||||||
__pycache__
|
__pycache__/
|
||||||
.idea/
|
.idea/
|
||||||
.vscode/
|
.vscode/
|
||||||
|
token.txt
|
||||||
|
|
62
py/api.py
62
py/api.py
|
@ -1,27 +1,52 @@
|
||||||
import requests
|
import requests
|
||||||
import json
|
import json
|
||||||
from transformers import AutoTokenizer, LlamaForCausalLM
|
from gradio_client import Client
|
||||||
|
import os
|
||||||
|
from mistralai import Mistral
|
||||||
|
|
||||||
class API:
|
class API:
|
||||||
# This method processes a message via transformers. (NOT FINISHED!)
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def process_text_transformers(prompt, model):
|
def process_text_mistralai(prompt, model, system):
|
||||||
model = LlamaForCausalLM.from_pretrained(model)
|
with open("token.txt", "r") as f:
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model)
|
token = f.readlines()[0].strip()
|
||||||
|
|
||||||
inputs = tokenizer(prompt, return_tensors="pt")
|
api_key = token
|
||||||
|
|
||||||
generate_ids = model.generate(inputs.input_ids, max_length=30)
|
client = Mistral(api_key=api_key)
|
||||||
return tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
|
||||||
|
chat_response = client.chat.complete(
|
||||||
|
model=model,
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": prompt,
|
||||||
|
}, {
|
||||||
|
"role": "system",
|
||||||
|
"content": system,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return chat_response.choices[0].message.content
|
||||||
|
@staticmethod
|
||||||
|
def process_text_gradio(prompt, model, system):
|
||||||
|
client = Client(model)
|
||||||
|
result = client.predict(
|
||||||
|
message=prompt,
|
||||||
|
system_message=system,
|
||||||
|
max_tokens=512,
|
||||||
|
temperature=0.7,
|
||||||
|
top_p=0.95,
|
||||||
|
api_name="/chat"
|
||||||
|
)
|
||||||
|
return result;
|
||||||
|
|
||||||
# This method processes a message via ollama
|
# This method processes a message via ollama
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def process_text_local(prompt, model):
|
def process_text_local(prompt, model, system):
|
||||||
ollama_url = "http://localhost:11434"
|
ollama_url = "http://localhost:11434"
|
||||||
|
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
f"{ollama_url}/api/generate", json={"model": model, "prompt": prompt}
|
f"{ollama_url}/api/generate", json={"model": model, "prompt": prompt, "system": system}
|
||||||
)
|
)
|
||||||
|
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
|
@ -37,17 +62,20 @@ class API:
|
||||||
return "Error: " + response.text
|
return "Error: " + response.text
|
||||||
|
|
||||||
# This method sends a message to a certain AI.
|
# This method sends a message to a certain AI.
|
||||||
def send_message(self, message, model):
|
|
||||||
|
def send_message(self, message, model, system):
|
||||||
if model == 1:
|
if model == 1:
|
||||||
answer = self.process_text_local(message, "phi3.5")
|
answer = self.process_text_local(message, "phi3.5", system)
|
||||||
elif model == 2:
|
elif model == 2:
|
||||||
answer = self.process_text_local(message, "gemma2:2b")
|
answer = self.process_text_local(message, "gemma2:9b", system)
|
||||||
elif model == 3:
|
elif model == 3:
|
||||||
answer = self.process_text_local(message, "qwen2:0.5b")
|
answer = self.process_text_local(message, "codegemma:2b", system)
|
||||||
elif model == 4:
|
elif model == 4:
|
||||||
answer = self.process_text_local(message, "codegemma:2b")
|
answer = self.process_text_gradio(message, "PatrickPluto/InterstellarAIChatbot", system)
|
||||||
elif model == 5:
|
elif model == 5:
|
||||||
answer = self.process_text_transformers(message, "meta-llama/Meta-Llama-3.1-8B")
|
answer = self.process_text_mistralai(message, "mistral-large-latest", system)
|
||||||
|
elif model == 6:
|
||||||
|
answer = self.process_text_mistralai(message, "codestral-latest", system)
|
||||||
else:
|
else:
|
||||||
return "Invalid choice"
|
return "Invalid choice"
|
||||||
return answer
|
return answer
|
||||||
|
|
17
py/install.sh
Executable file
17
py/install.sh
Executable file
|
@ -0,0 +1,17 @@
|
||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
python3 -m venv venv
|
||||||
|
source venv/bin/activate
|
||||||
|
pip install flask
|
||||||
|
pip install SpeechRecognition
|
||||||
|
pip install pyaudio
|
||||||
|
pip install pocketsphinx
|
||||||
|
pip install sentencepiece
|
||||||
|
pip install pyqt5
|
||||||
|
pip install pyqtwebengine
|
||||||
|
pip install gradio_client
|
||||||
|
pip install mistralai
|
||||||
|
|
||||||
|
ollama pull phi3.5
|
||||||
|
ollama pull codegemma:2b
|
||||||
|
ollama pull gemma2:9b
|
10
py/venv.sh
10
py/venv.sh
|
@ -1,10 +0,0 @@
|
||||||
#!/bin/bash
|
|
||||||
|
|
||||||
virtualenv venv
|
|
||||||
source venv/bin/activate
|
|
||||||
pip install transformers
|
|
||||||
pip install torch
|
|
||||||
pip install flask
|
|
||||||
pip install SpeechRecognition
|
|
||||||
pip install pyaudio
|
|
||||||
pip install pocketsphinx
|
|
43
py/web_flask.py
Normal file → Executable file
43
py/web_flask.py
Normal file → Executable file
|
@ -1,6 +1,14 @@
|
||||||
|
#!venv/bin/python
|
||||||
|
|
||||||
from flask import Flask, request, render_template
|
from flask import Flask, request, render_template
|
||||||
from api import API
|
from api import API
|
||||||
from voice_recognition import Voice
|
from voice_recognition import Voice
|
||||||
|
import sys
|
||||||
|
import threading
|
||||||
|
from PyQt5.QtCore import *
|
||||||
|
from PyQt5.QtWebEngineWidgets import *
|
||||||
|
from PyQt5.QtWidgets import *
|
||||||
|
|
||||||
|
|
||||||
APP = Flask(__name__)
|
APP = Flask(__name__)
|
||||||
api = API()
|
api = API()
|
||||||
|
@ -13,6 +21,8 @@ messages = []
|
||||||
@APP.route('/', methods=['GET', 'POST'])
|
@APP.route('/', methods=['GET', 'POST'])
|
||||||
def index():
|
def index():
|
||||||
global messages
|
global messages
|
||||||
|
system_prompt = 'You are a helpful assistant.'
|
||||||
|
system = 'Your system prompt is: \"'+system_prompt+'\" The following is your chat log so far: \n'
|
||||||
|
|
||||||
if request.method == 'POST':
|
if request.method == 'POST':
|
||||||
option = request.form['option']
|
option = request.form['option']
|
||||||
|
@ -25,7 +35,12 @@ def index():
|
||||||
elif option == "chat":
|
elif option == "chat":
|
||||||
messages.append(f"User: {user_message}")
|
messages.append(f"User: {user_message}")
|
||||||
|
|
||||||
ai_response = "AI: " + api.send_message(user_message, 3)
|
for line in messages:
|
||||||
|
system += line + '\n'
|
||||||
|
|
||||||
|
system += "The chat log is now finished."
|
||||||
|
|
||||||
|
ai_response = "AI: " + api.send_message(user_message, 5, system)
|
||||||
messages.append(ai_response)
|
messages.append(ai_response)
|
||||||
|
|
||||||
return render_template('index.html', messages=messages)
|
return render_template('index.html', messages=messages)
|
||||||
|
@ -42,6 +57,28 @@ def contact():
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
APP.run(debug=True)
|
qapp = QApplication(sys.argv)
|
||||||
|
|
||||||
# This is a comment --> test if this creates a merge conflict
|
view = QWebEngineView()
|
||||||
|
|
||||||
|
view.setGeometry(100, 100, 1280, 720)
|
||||||
|
view.setWindowTitle("InterstellarAI")
|
||||||
|
|
||||||
|
view.setUrl(QUrl("http://localhost:5000"))
|
||||||
|
|
||||||
|
view.show()
|
||||||
|
|
||||||
|
def run_flask():
|
||||||
|
APP.run()
|
||||||
|
|
||||||
|
def stop_flask():
|
||||||
|
thread.join()
|
||||||
|
qapp.quit()
|
||||||
|
|
||||||
|
thread = threading.Thread(target=run_flask)
|
||||||
|
thread.daemon = True
|
||||||
|
thread.start()
|
||||||
|
|
||||||
|
qapp.aboutToQuit.connect(stop_flask)
|
||||||
|
|
||||||
|
sys.exit(qapp.exec_())
|
||||||
|
|
Loading…
Reference in a new issue