Health/Assets/_VoiceAssistant/AIChatTookit/Scripts/LLM/Baidu/ChatBaidu.cs

232 lines
6.4 KiB
C#
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

using Newtonsoft.Json;
using System;
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using UnityEngine.Networking;
/// <summary>
/// Chat client for Baidu's "wenxinworkshop" chat endpoints.
/// Keeps the running conversation in <c>m_History</c> and posts the whole
/// history with every request. The endpoint path is selected from
/// <see cref="m_ModelType"/> when the component awakes.
/// </summary>
public class ChatBaidu : LLM
{
    /// <summary>
    /// Common prefix of every chat endpoint; the model-specific path segment is appended to it.
    /// </summary>
    private const string ChatEndpointBase = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/";

    /// <summary>
    /// Sets a default endpoint (the "eb-instant" path); <see cref="Awake"/> later
    /// replaces it with the endpoint matching <see cref="m_ModelType"/>.
    /// </summary>
    public ChatBaidu()
    {
        url = ChatEndpointBase + "eb-instant";
    }

    void Awake()
    {
        OnInit();
    }

    /// <summary>
    /// Component that provides the access token (<c>m_Token</c>).
    /// </summary>
    [SerializeField] private BaiduSettings m_Settings;

    /// <summary>
    /// Conversation history sent with every request.
    /// </summary>
    private List<message> m_History = new List<message>();

    /// <summary>
    /// Selected model type.
    /// </summary>
    [Header("设置模型名称")]
    public ModelType m_ModelType = ModelType.ERNIE_Bot_turbo;

    /// <summary>
    /// Initialization: resolves the settings component and the endpoint URL.
    /// </summary>
    private void OnInit()
    {
        // Keep a reference assigned in the inspector; only fall back to a
        // component on this GameObject when nothing was assigned.
        if (m_Settings == null)
        {
            m_Settings = this.GetComponent<BaiduSettings>();
        }
        GetEndPointUrl();
    }

    /// <summary>
    /// Sends a message. Delegates entirely to the base implementation.
    /// </summary>
    /// <param name="_msg">Text to send.</param>
    /// <param name="_callback">Receives the reply text.</param>
    public override void PostMsg(string _msg, Action<string> _callback)
    {
        base.PostMsg(_msg, _callback);
    }

    /// <summary>
    /// Coroutine that posts the conversation (history plus <paramref name="_postWord"/>)
    /// to the configured endpoint and, on success, records the assistant reply and
    /// passes it to <paramref name="_callback"/>.
    /// On any failure (non-200 status, unparsable body, or an error payload) the
    /// error is logged, the callback is not invoked, and the user message is removed
    /// from the history again so the next request does not contain two consecutive
    /// "user" turns.
    /// </summary>
    /// <param name="_postWord">User text for this turn.</param>
    /// <param name="_callback">Invoked with the reply text on success; may be null.</param>
    /// <returns>Enumerator to be run as a Unity coroutine.</returns>
    public override IEnumerator Request(string _postWord, System.Action<string> _callback)
    {
        if (m_Settings == null)
        {
            Debug.LogError("ChatBaidu: BaiduSettings is missing, cannot send the request.");
            yield break;
        }

        stopwatch.Restart();
        string _postUrl = url + "?access_token=" + m_Settings.m_Token;

        // Keep the instance so exactly this entry can be rolled back on failure.
        message _userMessage = new message("user", _postWord);
        m_History.Add(_userMessage);

        RequestData _postData = new RequestData
        {
            messages = m_History
        };

        bool _succeeded = false;
        // The using block disposes the request on every path; no manual Dispose needed.
        using (UnityWebRequest request = new UnityWebRequest(_postUrl, "POST"))
        {
            string _jsonData = JsonUtility.ToJson(_postData);
            byte[] data = System.Text.Encoding.UTF8.GetBytes(_jsonData);
            request.uploadHandler = (UploadHandler)new UploadHandlerRaw(data);
            request.downloadHandler = (DownloadHandler)new DownloadHandlerBuffer();
            request.SetRequestHeader("Content-Type", "application/json");

            yield return request.SendWebRequest();

            if (request.responseCode == 200)
            {
                _succeeded = HandleResponse(request.downloadHandler.text, _callback);
            }
            else
            {
                Debug.LogError("ChatBaidu: request failed, HTTP " + request.responseCode + ", error: " + request.error);
            }
        }

        if (!_succeeded)
        {
            // Roll back the unanswered user turn.
            m_History.Remove(_userMessage);
        }

        stopwatch.Stop();
        Debug.Log("chat百度-耗时:" + stopwatch.Elapsed.TotalSeconds);
    }

    /// <summary>
    /// Parses a response body, records the assistant reply in the history and in
    /// <c>m_DataList</c>, and invokes the callback.
    /// </summary>
    /// <param name="_body">Raw JSON response body.</param>
    /// <param name="_callback">Invoked with the reply text; may be null.</param>
    /// <returns>True when a reply was recorded; false when the body was unusable.</returns>
    private bool HandleResponse(string _body, System.Action<string> _callback)
    {
        ResponseData response;
        try
        {
            response = JsonConvert.DeserializeObject<ResponseData>(_body);
        }
        catch (JsonException e)
        {
            Debug.LogError("ChatBaidu: could not parse the response: " + e.Message);
            return false;
        }

        if (response == null)
        {
            Debug.LogError("ChatBaidu: the response body was empty.");
            return false;
        }

        // NOTE(review): assumes the service reports failures as error_code/error_msg
        // in the body even with HTTP 200 - confirm against the API documentation.
        if (response.error_code != 0)
        {
            Debug.LogError("ChatBaidu: service error " + response.error_code + ": " + response.error_msg);
            return false;
        }

        // Conversation history
        m_History.Add(new message("assistant", response.result));
        // Record list inherited from the base class
        m_DataList.Add(new SendData("assistant", response.result));
        // Callback
        if (_callback != null)
        {
            _callback(response.result);
        }
        return true;
    }

    /// <summary>
    /// Builds the endpoint URL for the selected model.
    /// </summary>
    private void GetEndPointUrl()
    {
        url = ChatEndpointBase + CheckModelType(m_ModelType);
    }

    /// <summary>
    /// Maps a model type to its endpoint path segment.
    /// </summary>
    /// <param name="_type">Model type to look up.</param>
    /// <returns>The path segment, or an empty string for an unmapped value.</returns>
    private string CheckModelType(ModelType _type)
    {
        switch (_type)
        {
            case ModelType.ERNIE_Bot:
                return "completions";
            case ModelType.ERNIE_Bot_turbo:
                return "eb-instant";
            case ModelType.BLOOMZ_7B:
                return "bloomz_7b1";
            case ModelType.Qianfan_BLOOMZ_7B_compressed:
                return "qianfan_bloomz_7b_compressed";
            case ModelType.ChatGLM2_6B_32K:
                return "chatglm2_6b_32k";
            case ModelType.Llama_2_7B_Chat:
                return "llama_2_7b";
            case ModelType.Llama_2_13B_Chat:
                return "llama_2_13b";
            case ModelType.Llama_2_70B_Chat:
                return "llama_2_70b";
            case ModelType.Qianfan_Chinese_Llama_2_7B:
                return "qianfan_chinese_llama_2_7b";
            case ModelType.AquilaChat_7B:
                return "aquilachat_7b";
            default:
                return "";
        }
    }

    #region
    // Request payload
    [Serializable]
    private class RequestData
    {
        public List<message> messages = new List<message>(); // Messages to send
        public bool stream = false; // Whether to use streaming output
        public string user_id = string.Empty;
    }

    // One conversation turn
    [Serializable]
    private class message
    {
        public string role = string.Empty; // Role
        public string content = string.Empty; // Message text
        public message() { }
        public message(string _role, string _content)
        {
            role = _role;
            content = _content;
        }
    }

    // Response payload
    [Serializable]
    private class ResponseData
    {
        public string id = string.Empty; // Id of this conversation round
        public int created;
        public int sentence_id; // Index of the current sentence; only returned in streaming mode
        public bool is_end; // Whether the current sentence is the last one; only returned in streaming mode
        public bool is_truncated; // NOTE(review): presumably whether the result was truncated - the original comment duplicated the is_end description
        public string result = string.Empty; // Returned text
        public bool need_clear_history; // Whether the user input raised a safety concern
        public int ban_round; // When need_clear_history is true: which round held sensitive content; -1 if it is the current question
        public Usage usage = new Usage(); // Token statistics; token count = Chinese characters + words * 1.3
        public int error_code; // NOTE(review): assumed error field, stays 0 when absent - confirm against the API documentation
        public string error_msg = string.Empty; // NOTE(review): assumed error description field - confirm against the API documentation
    }

    [Serializable]
    private class Usage
    {
        public int prompt_tokens; // Tokens in the prompt
        public int completion_tokens; // Tokens in the reply
        public int total_tokens; // Total tokens
    }
    #endregion

    /// <summary>
    /// Supported model names.
    /// </summary>
    public enum ModelType
    {
        ERNIE_Bot,
        ERNIE_Bot_turbo,
        BLOOMZ_7B,
        Qianfan_BLOOMZ_7B_compressed,
        ChatGLM2_6B_32K,
        Llama_2_7B_Chat,
        Llama_2_13B_Chat,
        Llama_2_70B_Chat,
        Qianfan_Chinese_Llama_2_7B,
        AquilaChat_7B,
    }
}